You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original page and is not preserved in this plain-text copy.
Posted to commits@samza.apache.org by ni...@apache.org on 2016/07/20 16:56:48 UTC
[2/3] samza git commit: SAMZA-863: Multithreading support in Samza
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index 18c0922..b8600d5 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -20,12 +20,19 @@
package org.apache.samza.container
import java.io.File
+import java.lang.Thread.UncaughtExceptionHandler
+import java.net.URL
+import java.net.UnknownHostException
import java.nio.file.Path
import java.util
-import java.lang.Thread.UncaughtExceptionHandler
-import java.net.{URL, UnknownHostException}
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
+import java.util.concurrent.TimeUnit
+
import org.apache.samza.SamzaException
-import org.apache.samza.checkpoint.{CheckpointManagerFactory, OffsetManager, OffsetManagerMetrics}
+import org.apache.samza.checkpoint.CheckpointManagerFactory
+import org.apache.samza.checkpoint.OffsetManager
+import org.apache.samza.checkpoint.OffsetManagerMetrics
import org.apache.samza.config.JobConfig.Config2Job
import org.apache.samza.config.MetricsConfig.Config2Metrics
import org.apache.samza.config.SerializerConfig.Config2Serializer
@@ -34,18 +41,45 @@ import org.apache.samza.config.StorageConfig.Config2Storage
import org.apache.samza.config.StreamConfig.Config2Stream
import org.apache.samza.config.SystemConfig.Config2System
import org.apache.samza.config.TaskConfig.Config2Task
+import org.apache.samza.container.disk.DiskQuotaPolicyFactory
+import org.apache.samza.container.disk.DiskSpaceMonitor
import org.apache.samza.container.disk.DiskSpaceMonitor.Listener
-import org.apache.samza.container.disk.{NoThrottlingDiskQuotaPolicyFactory, DiskQuotaPolicyFactory, PollingScanDiskSpaceMonitor, DiskSpaceMonitor}
+import org.apache.samza.container.disk.NoThrottlingDiskQuotaPolicyFactory
+import org.apache.samza.container.disk.PollingScanDiskSpaceMonitor
import org.apache.samza.coordinator.stream.CoordinatorStreamSystemFactory
-import org.apache.samza.job.model.{ContainerModel, JobModel}
-import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, MetricsReporter, MetricsReporterFactory}
-import org.apache.samza.serializers.{SerdeFactory, SerdeManager}
+import org.apache.samza.job.model.ContainerModel
+import org.apache.samza.job.model.JobModel
+import org.apache.samza.metrics.JmxServer
+import org.apache.samza.metrics.JvmMetrics
+import org.apache.samza.metrics.MetricsRegistryMap
+import org.apache.samza.metrics.MetricsReporter
+import org.apache.samza.metrics.MetricsReporterFactory
+import org.apache.samza.serializers.SerdeFactory
+import org.apache.samza.serializers.SerdeManager
import org.apache.samza.serializers.model.SamzaObjectMapper
-import org.apache.samza.storage.{StorageEngineFactory, TaskStorageManager}
-import org.apache.samza.system.{StreamMetadataCache, SystemConsumers, SystemConsumersMetrics, SystemFactory, SystemProducers, SystemProducersMetrics, SystemStream, SystemStreamPartition}
-import org.apache.samza.system.chooser.{DefaultChooser, MessageChooserFactory, RoundRobinChooserFactory}
-import org.apache.samza.task.{StreamTask, TaskInstanceCollector}
-import org.apache.samza.util.{ThrottlingExecutor, ExponentialSleepStrategy, Logging, Util}
+import org.apache.samza.storage.StorageEngineFactory
+import org.apache.samza.storage.TaskStorageManager
+import org.apache.samza.system.StreamMetadataCache
+import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemConsumersMetrics
+import org.apache.samza.system.SystemFactory
+import org.apache.samza.system.SystemProducers
+import org.apache.samza.system.SystemProducersMetrics
+import org.apache.samza.system.SystemStream
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.system.chooser.DefaultChooser
+import org.apache.samza.system.chooser.MessageChooserFactory
+import org.apache.samza.system.chooser.RoundRobinChooserFactory
+import org.apache.samza.task.AsyncRunLoop
+import org.apache.samza.task.AsyncStreamTask
+import org.apache.samza.task.AsyncStreamTaskAdapter
+import org.apache.samza.task.StreamTask
+import org.apache.samza.task.TaskInstanceCollector
+import org.apache.samza.util.ExponentialSleepStrategy
+import org.apache.samza.util.Logging
+import org.apache.samza.util.ThrottlingExecutor
+import org.apache.samza.util.Util
+
import scala.collection.JavaConversions._
object SamzaContainer extends Logging {
@@ -164,6 +198,12 @@ object SamzaContainer extends Logging {
info("Got input stream metadata: %s" format inputStreamMetadata)
+ val taskClassName = config
+ .getTaskClass
+ .getOrElse(throw new SamzaException("No task class defined in configuration.")) // resolved this early because isAsyncTask (computed right after the consumers) is derived from it
+
+ info("Got stream task class: %s" format taskClassName)
+
val consumers = inputSystems
.map(systemName => {
val systemFactory = systemFactories(systemName)
@@ -181,6 +221,9 @@ object SamzaContainer extends Logging {
info("Got system consumers: %s" format consumers.keys)
+ val isAsyncTask = classOf[AsyncStreamTask].isAssignableFrom(Class.forName(taskClassName)) // AsyncStreamTasks are rejected in single thread mode
+ info("%s is AsyncStreamTask: %s" format (taskClassName, isAsyncTask))
+
val producers = systemFactories
.map {
case (systemName, systemFactory) =>
@@ -360,26 +403,22 @@ object SamzaContainer extends Logging {
info("Got storage engines: %s" format storageEngineFactories.keys)
- val taskClassName = config
- .getTaskClass
- .getOrElse(throw new SamzaException("No task class defined in configuration."))
-
- info("Got stream task class: %s" format taskClassName)
-
- val taskWindowMs = config.getWindowMs.getOrElse(-1L)
-
- info("Got window milliseconds: %s" format taskWindowMs)
+ val singleThreadMode = config.getSingleThreadMode
+ info("Got single thread mode: " + singleThreadMode)
- val taskCommitMs = config.getCommitMs.getOrElse(60000L)
-
- info("Got commit milliseconds: %s" format taskCommitMs)
+ if(singleThreadMode && isAsyncTask) { // AsyncStreamTask and single thread mode are mutually exclusive
+ throw new SamzaException("AsyncStreamTask %s cannot run on single thread mode." format taskClassName)
+ }
- val taskShutdownMs = config.getShutdownMs.getOrElse(5000L)
+ val threadPoolSize = config.getThreadPoolSize
+ info("Got thread pool size: " + threadPoolSize)
- info("Got shutdown timeout milliseconds: %s" format taskShutdownMs)
+ val taskThreadPool = if (!singleThreadMode && threadPoolSize > 0) // only created in multithreaded mode with a positive pool size
+ Executors.newFixedThreadPool(threadPoolSize)
+ else
+ null // NOTE(review): this null is passed on to AsyncStreamTaskAdapter and RunLoopFactory — confirm both accept it
// Wire up all task-instance-level (unshared) objects.
-
val taskNames = containerModel
.getTasks
.values
@@ -395,12 +434,18 @@ object SamzaContainer extends Logging {
val storeWatchPaths = new util.HashSet[Path]()
storeWatchPaths.add(defaultStoreBaseDir.toPath)
- val taskInstances: Map[TaskName, TaskInstance] = containerModel.getTasks.values.map(taskModel => {
+ val taskInstances: Map[TaskName, TaskInstance[_]] = containerModel.getTasks.values.map(taskModel => {
debug("Setting up task instance: %s" format taskModel)
val taskName = taskModel.getTaskName
- val task = Util.getObj[StreamTask](taskClassName)
+ val taskObj = Class.forName(taskClassName).newInstance
+
+ val task = if (!singleThreadMode && !isAsyncTask)
+ // Wrap the plain StreamTask in an AsyncStreamTaskAdapter backed by taskThreadPool (which may be null if no pool was created)
+ new AsyncStreamTaskAdapter(taskObj.asInstanceOf[StreamTask], taskThreadPool)
+ else
+ taskObj // an AsyncStreamTask, or a plain StreamTask in single thread mode
val taskInstanceMetrics = new TaskInstanceMetrics("TaskName-%s" format taskName)
@@ -487,20 +532,22 @@ object SamzaContainer extends Logging {
info("Retrieved SystemStreamPartitions " + systemStreamPartitions + " for " + taskName)
- val taskInstance = new TaskInstance(
- task = task,
- taskName = taskName,
- config = config,
- metrics = taskInstanceMetrics,
- systemAdmins = systemAdmins,
- consumerMultiplexer = consumerMultiplexer,
- collector = collector,
- containerContext = containerContext,
- offsetManager = offsetManager,
- storageManager = storageManager,
- reporters = reporters,
- systemStreamPartitions = systemStreamPartitions,
- exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config))
+ def createTaskInstance[T] (task: T ): TaskInstance[T] = new TaskInstance[T]( // T is inferred from the static type of the task passed in
+ task = task,
+ taskName = taskName,
+ config = config,
+ metrics = taskInstanceMetrics,
+ systemAdmins = systemAdmins,
+ consumerMultiplexer = consumerMultiplexer,
+ collector = collector,
+ containerContext = containerContext,
+ offsetManager = offsetManager,
+ storageManager = storageManager,
+ reporters = reporters,
+ systemStreamPartitions = systemStreamPartitions,
+ exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config))
+
+ val taskInstance = createTaskInstance(task)
(taskName, taskInstance)
}).toMap
@@ -533,14 +580,13 @@ object SamzaContainer extends Logging {
info(s"Disk quotas disabled because polling interval is not set ($DISK_POLL_INTERVAL_KEY)")
}
- val runLoop = new RunLoop(
- taskInstances = taskInstances,
- consumerMultiplexer = consumerMultiplexer,
- metrics = samzaContainerMetrics,
- windowMs = taskWindowMs,
- commitMs = taskCommitMs,
- shutdownMs = taskShutdownMs,
- executor = executor)
+ val runLoop = RunLoopFactory.createRunLoop( // NOTE(review): addShutdownHook only knows how to stop RunLoop and AsyncRunLoop results
+ taskInstances,
+ consumerMultiplexer,
+ taskThreadPool,
+ executor,
+ samzaContainerMetrics,
+ config)
info("Samza container setup complete.")
@@ -557,14 +603,15 @@ object SamzaContainer extends Logging {
reporters = reporters,
jvm = jvm,
jmxServer = jmxServer,
- diskSpaceMonitor = diskSpaceMonitor)
+ diskSpaceMonitor = diskSpaceMonitor,
+ taskThreadPool = taskThreadPool) // null unless a task thread pool was created above
}
}
class SamzaContainer(
containerContext: SamzaContainerContext,
- taskInstances: Map[TaskName, TaskInstance],
- runLoop: RunLoop,
+ taskInstances: Map[TaskName, TaskInstance[_]],
+ runLoop: Runnable, // widened from RunLoop so either run loop implementation can be passed
consumerMultiplexer: SystemConsumers,
producerMultiplexer: SystemProducers,
metrics: SamzaContainerMetrics,
@@ -574,7 +621,10 @@ class SamzaContainer(
localityManager: LocalityManager = null,
securityManager: SecurityManager = null,
reporters: Map[String, MetricsReporter] = Map(),
- jvm: JvmMetrics = null) extends Runnable with Logging {
+ jvm: JvmMetrics = null,
+ taskThreadPool: ExecutorService = null) extends Runnable with Logging { // null when no task thread pool was created
+
+ val shutdownMs = containerContext.config.getShutdownMs.getOrElse(5000L) // bounds both the shutdown hook's join and the task thread pool drain
def run {
try {
@@ -591,6 +641,7 @@ class SamzaContainer(
startSecurityManger
info("Entering run loop.")
+ addShutdownHook // registered from this thread so the hook can join the thread that runs the loop
runLoop.run
} catch {
case e: Exception =>
@@ -710,7 +761,7 @@ class SamzaContainer(
consumerMultiplexer.start
}
- def startSecurityManger: Unit = {
+ def startSecurityManger { // NOTE(review): procedure syntax matches this class's other methods but is deprecated in Scala 2.13+
if (securityManager != null) {
info("Starting security manager.")
@@ -718,6 +769,25 @@ class SamzaContainer(
}
}
+ def addShutdownHook {
+ // JVM hook: ask the run loop to stop, then wait up to shutdownMs for the current (run loop) thread.
+ val runLoopThread = Thread.currentThread()
+ Runtime.getRuntime().addShutdownHook(new Thread() {
+ override def run() = {
+ info("Shutting down, will wait up to %s ms" format shutdownMs)
+ runLoop match {
+ case runLoop: RunLoop => runLoop.shutdown
+ case asyncRunLoop: AsyncRunLoop => asyncRunLoop.shutdown()
+ // runLoop is only typed as Runnable; avoid throwing a MatchError inside the shutdown hook.
+ case other => warn("Unknown run loop type %s, cannot request shutdown" format other.getClass.getName)
+ }
+ runLoopThread.join(shutdownMs)
+ if (runLoopThread.isAlive) warn("Did not shut down within %s ms, exiting" format shutdownMs)
+ else info("Shutdown complete")
+ }
+ })
+ }
+
def shutdownConsumers {
info("Shutting down consumer multiplexer.")
@@ -733,6 +803,19 @@ class SamzaContainer(
def shutdownTask {
info("Shutting down task instance stream tasks.")
+ // Drain the task thread pool for up to shutdownMs; force-stop (shutdownNow) only if work is still running after that.
+ if (taskThreadPool != null) {
+ info("Shutting down task thread pool")
+ try {
+ taskThreadPool.shutdown()
+ if (!taskThreadPool.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) {
+ taskThreadPool.shutdownNow()
+ }
+ } catch {
+ case e: Exception => error(e.getMessage, e)
+ }
+ }
+
taskInstances.values.foreach(_.shutdownTask)
}
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
index 2044ce0..e3891cf 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
@@ -34,9 +34,11 @@ class SamzaContainerMetrics(
val envelopes = newCounter("process-envelopes")
val nullEnvelopes = newCounter("process-null-envelopes")
val chooseNs = newTimer("choose-ns")
+ val chooserUpdateNs = newTimer("chooser-update-ns") // NOTE(review): presumably times chooser updates done outside choose() — confirm against the run loop
val windowNs = newTimer("window-ns")
val processNs = newTimer("process-ns")
val commitNs = newTimer("commit-ns")
+ val blockNs = newTimer("block-ns") // NOTE(review): presumably time the run loop spends blocked waiting — confirm
val utilization = newGauge("event-loop-utilization", 0.0F)
val diskUsageBytes = newGauge("disk-usage-bytes", 0L)
val diskQuotaBytes = newGauge("disk-quota-bytes", Long.MaxValue)
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index d32a929..89f6857 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -19,36 +19,39 @@
package org.apache.samza.container
+
import org.apache.samza.SamzaException
import org.apache.samza.checkpoint.OffsetManager
import org.apache.samza.config.Config
-import org.apache.samza.config.TaskConfig.Config2Task
import org.apache.samza.metrics.MetricsReporter
import org.apache.samza.storage.TaskStorageManager
import org.apache.samza.system.IncomingMessageEnvelope
-import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.system.SystemAdmin
import org.apache.samza.system.SystemConsumers
-import org.apache.samza.task.TaskContext
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.task.AsyncStreamTask
import org.apache.samza.task.ClosableTask
import org.apache.samza.task.InitableTask
-import org.apache.samza.task.WindowableTask
-import org.apache.samza.task.StreamTask
import org.apache.samza.task.ReadableCoordinator
+import org.apache.samza.task.StreamTask
+import org.apache.samza.task.TaskCallbackFactory
+import org.apache.samza.task.TaskContext
import org.apache.samza.task.TaskInstanceCollector
+import org.apache.samza.task.WindowableTask
import org.apache.samza.util.Logging
+
import scala.collection.JavaConversions._
-import org.apache.samza.system.SystemAdmin
-class TaskInstance(
- task: StreamTask,
+class TaskInstance[T]( // T is expected to be a StreamTask or an AsyncStreamTask; process() casts to one of them
+ task: T,
val taskName: TaskName,
config: Config,
- metrics: TaskInstanceMetrics,
+ val metrics: TaskInstanceMetrics, // NOTE(review): metrics and offsetManager became public vals here — presumably for the run loops; confirm
systemAdmins: Map[String, SystemAdmin],
consumerMultiplexer: SystemConsumers,
collector: TaskInstanceCollector,
containerContext: SamzaContainerContext,
- offsetManager: OffsetManager = new OffsetManager,
+ val offsetManager: OffsetManager = new OffsetManager,
storageManager: TaskStorageManager = null,
reporters: Map[String, MetricsReporter] = Map(),
val systemStreamPartitions: Set[SystemStreamPartition] = Set(),
@@ -56,6 +59,8 @@ class TaskInstance(
val isInitableTask = task.isInstanceOf[InitableTask]
val isWindowableTask = task.isInstanceOf[WindowableTask]
val isClosableTask = task.isInstanceOf[ClosableTask]
+ val isAsyncTask = task.isInstanceOf[AsyncStreamTask] // selects processAsync (true) or process (false) in process()
+
val context = new TaskContext {
def getMetricsRegistry = metrics.registry
def getSystemStreamPartitions = systemStreamPartitions
@@ -133,7 +138,7 @@ class TaskInstance(
})
}
- def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator) {
+ def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator, callbackFactory: TaskCallbackFactory = null) { // callbackFactory must be non-null for AsyncStreamTasks
metrics.processes.inc
if (!ssp2catchedupMapping.getOrElse(envelope.getSystemStreamPartition,
@@ -146,13 +151,20 @@ class TaskInstance(
trace("Processing incoming message envelope for taskName and SSP: %s, %s" format (taskName, envelope.getSystemStreamPartition))
- exceptionHandler.maybeHandle {
- task.process(envelope, collector, coordinator)
- }
+ if (isAsyncTask) { // NOTE(review): offsets are not updated on this path; presumably done when the callback completes — confirm
+ require(callbackFactory != null, "AsyncStreamTask in task %s requires a TaskCallbackFactory" format taskName)
+ exceptionHandler.maybeHandle {
+ task.asInstanceOf[AsyncStreamTask].processAsync(envelope, collector, coordinator, callbackFactory.createCallback())
+ }
+ } else {
+ exceptionHandler.maybeHandle {
+ task.asInstanceOf[StreamTask].process(envelope, collector, coordinator)
+ }
- trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
+ trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
- offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset)
+ offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset)
+ }
}
}
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
index 8b86388..7bedadf 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
@@ -35,6 +35,8 @@ class TaskInstanceMetrics(
val sends = newCounter("send-calls")
val flushes = newCounter("flush-calls")
val messagesSent = newCounter("messages-sent")
+ val pendingMessages = newGauge("pending-messages", 0) // NOTE(review): not updated in this excerpt; presumably set by the async run loop — confirm
+ val messagesInFlight = newGauge("messages-in-flight", 0) // NOTE(review): presumably bounded by the run loop's max messages in flight — confirm
def addOffsetGauge(systemStreamPartition: SystemStreamPartition, getValue: () => String) {
newGauge("%s-%s-%d-offset" format (systemStreamPartition.getSystem, systemStreamPartition.getStream, systemStreamPartition.getPartition.getPartitionId), getValue)
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
index d3bd9b7..ba38b5c 100644
--- a/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
+++ b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
@@ -71,9 +71,12 @@ object JobModelManager extends Logging {
coordinatorSystemConsumer.start
debug("Bootstrapping coordinator system stream.")
coordinatorSystemConsumer.bootstrap
+ val source = "Job-coordinator" // also passed to the ChangelogPartitionManager below
+ info("Registering coordinator system stream producer.")
+ coordinatorSystemProducer.register(source)
val config = coordinatorSystemConsumer.getConfig
info("Got config: %s" format config)
- val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, "Job-coordinator")
+ val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, source)
val localityManager = new LocalityManager(coordinatorSystemProducer, coordinatorSystemConsumer)
val systemNames = getSystemNames(config)
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
index 2efe836..a8355b9 100644
--- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
+++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
@@ -99,7 +99,7 @@ class SystemConsumers (
* with no remaining unprocessed messages, the SystemConsumers will poll for
* it within 50ms of its availability in the stream system.</p>
*/
- pollIntervalMs: Int,
+ val pollIntervalMs: Int, // NOTE(review): made public in this change — presumably read by the async run loop; confirm
/**
* Clock can be used to inject a custom clock when mocking this class in
@@ -203,28 +203,31 @@ class SystemConsumers (
}
}
- def choose: IncomingMessageEnvelope = {
+ def choose (updateChooser: Boolean = true): IncomingMessageEnvelope = { // NOTE(review): callers must now write choose(); the old parameterless `choose` call no longer compiles
val envelopeFromChooser = chooser.choose
updateTimer(metrics.deserializationNs) {
if (envelopeFromChooser == null) {
- trace("Chooser returned null.")
+ trace("Chooser returned null.")
- metrics.choseNull.inc
+ metrics.choseNull.inc
- // Sleep for a while so we don't poll in a tight loop.
- timeout = noNewMessagesTimeout
+ // Sleep for a while so we don't poll in a tight loop.
+ timeout = noNewMessagesTimeout
} else {
- val systemStreamPartition = envelopeFromChooser.getSystemStreamPartition
+ val systemStreamPartition = envelopeFromChooser.getSystemStreamPartition
- trace("Chooser returned an incoming message envelope: %s" format envelopeFromChooser)
+ trace("Chooser returned an incoming message envelope: %s" format envelopeFromChooser)
- // Ok to give the chooser a new message from this stream.
- timeout = 0
- metrics.choseObject.inc
- metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition).inc
+ // Ok to give the chooser a new message from this stream (done below only when updateChooser is true).
+ timeout = 0
+ metrics.choseObject.inc
+ metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition).inc
- tryUpdate(systemStreamPartition)
+ if (updateChooser) { // with updateChooser = false the caller presumably invokes tryUpdate(ssp) itself later
+ trace("Update chooser for " + systemStreamPartition.getPartition)
+ tryUpdate(systemStreamPartition)
+ }
}
}
@@ -287,7 +290,7 @@ class SystemConsumers (
}
}
- private def tryUpdate(ssp: SystemStreamPartition) {
+ def tryUpdate(ssp: SystemStreamPartition) { // public so a caller of choose(updateChooser = false) can refill the chooser itself
var updated = false
try {
updated = update(ssp)
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
new file mode 100644
index 0000000..ca913de
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.Partition;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.config.Config;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.container.SamzaContainerMetrics;
+import org.apache.samza.container.TaskInstance;
+import org.apache.samza.container.TaskInstanceExceptionHandler;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemConsumers;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Before;
+import org.junit.Test;
+import scala.collection.JavaConversions;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class TestAsyncRunLoop {
+ // Drives AsyncRunLoop against a mocked SystemConsumers; TestTask callbacks complete on callbackExecutor.
+ Map<TaskName, TaskInstance<AsyncStreamTask>> tasks;
+ ExecutorService executor;
+ SystemConsumers consumerMultiplexer;
+ SamzaContainerMetrics containerMetrics;
+ OffsetManager offsetManager;
+ long windowMs;
+ long commitMs;
+ long callbackTimeoutMs;
+ int maxMessagesInFlight;
+ TaskCoordinator.RequestScope commitRequest;
+ TaskCoordinator.RequestScope shutdownRequest;
+
+ Partition p0 = new Partition(0);
+ Partition p1 = new Partition(1);
+ TaskName taskName0 = new TaskName(p0.toString());
+ TaskName taskName1 = new TaskName(p1.toString());
+ SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
+ SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
+ IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+ IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+ IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+
+ TestTask task0;
+ TestTask task1;
+ TaskInstance<AsyncStreamTask> t0;
+ TaskInstance<AsyncStreamTask> t1;
+
+ AsyncRunLoop createRunLoop() {
+ return new AsyncRunLoop(tasks,
+ executor,
+ consumerMultiplexer,
+ maxMessagesInFlight,
+ windowMs,
+ commitMs,
+ callbackTimeoutMs,
+ containerMetrics);
+ }
+
+ TaskInstance<AsyncStreamTask> createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp) {
+ TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
+ scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConversions.asScalaSet(Collections.singleton(ssp)).toSet();
+ return new TaskInstance<AsyncStreamTask>(task, taskName, mock(Config.class), taskInstanceMetrics,
+ null, consumerMultiplexer, mock(TaskInstanceCollector.class), mock(SamzaContainerContext.class),
+ offsetManager, null, null, sspSet, new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()));
+ }
+
+ ExecutorService callbackExecutor;
+ void triggerCallback(final TestTask task, final TaskCallback callback, final boolean success) {
+ callbackExecutor.submit(new Runnable() {
+ @Override
+ public void run() {
+ if (task.code != null) {
+ task.code.run(callback);
+ }
+
+ task.completed.incrementAndGet();
+
+ if (success) {
+ callback.complete();
+ } else {
+ callback.failure(new Exception("process failure"));
+ }
+ }
+ });
+ }
+
+ interface TestCode {
+ void run(TaskCallback callback);
+ }
+
+ class TestTask implements AsyncStreamTask, WindowableTask {
+ boolean shutdown = false;
+ boolean commit = false;
+ boolean success;
+ int processed = 0;
+ volatile int windowCount = 0;
+
+ AtomicInteger completed = new AtomicInteger(0);
+ TestCode code = null;
+
+ TestTask(boolean success, boolean commit, boolean shutdown) {
+ this.success = success;
+ this.shutdown = shutdown;
+ this.commit = commit;
+ }
+
+ @Override
+ public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator,
+ TaskCallback callback) {
+
+ if (maxMessagesInFlight == 1) {
+ assertEquals(processed, completed.get());
+ }
+
+ processed++;
+
+ if (commit) {
+ coordinator.commit(commitRequest);
+ }
+
+ if (shutdown) {
+ coordinator.shutdown(shutdownRequest);
+ }
+ triggerCallback(this, callback, success);
+ }
+
+ @Override
+ public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+ windowCount++;
+
+ if (shutdown && windowCount == 4) {
+ coordinator.shutdown(shutdownRequest);
+ }
+ }
+ }
+
+ @Before
+ public void setup() {
+ executor = null;
+ consumerMultiplexer = mock(SystemConsumers.class);
+ windowMs = -1;
+ commitMs = -1;
+ maxMessagesInFlight = 1;
+ containerMetrics = new SamzaContainerMetrics("container", new MetricsRegistryMap());
+ callbackExecutor = Executors.newFixedThreadPool(2);
+ offsetManager = mock(OffsetManager.class);
+ shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
+
+ when(consumerMultiplexer.pollIntervalMs()).thenReturn(1000000);
+
+ tasks = new HashMap<>();
+ task0 = new TestTask(true, true, false);
+ task1 = new TestTask(true, false, true);
+ t0 = createTaskInstance(task0, taskName0, ssp0);
+ t1 = createTaskInstance(task1, taskName1, ssp1);
+ tasks.put(taskName0, t0);
+ tasks.put(taskName1, t1);
+ }
+
+ @Test
+ public void testProcessMultipleTasks() throws Exception {
+ AsyncRunLoop runLoop = createRunLoop();
+ when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+ runLoop.run();
+ callbackExecutor.shutdown();
+ callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+ assertEquals(1, task0.processed);
+ assertEquals(1, task0.completed.get());
+ assertEquals(1, task1.processed);
+ assertEquals(1, task1.completed.get());
+ assertEquals(2L, containerMetrics.envelopes().getCount());
+ assertEquals(2L, containerMetrics.processes().getCount());
+ }
+
+ @Test
+ public void testProcessInOrder() throws Exception {
+ AsyncRunLoop runLoop = createRunLoop();
+ when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null);
+ runLoop.run();
+ callbackExecutor.shutdown();
+ callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+ assertEquals(2, task0.processed);
+ assertEquals(2, task0.completed.get());
+ assertEquals(1, task1.processed);
+ assertEquals(1, task1.completed.get());
+ assertEquals(3L, containerMetrics.envelopes().getCount());
+ assertEquals(3L, containerMetrics.processes().getCount());
+ }
+
+ @Test
+ public void testProcessOutOfOrder() throws Exception {
+ maxMessagesInFlight = 2;
+
+ final CountDownLatch latch = new CountDownLatch(1);
+ task0.code = new TestCode() {
+ @Override
+ public void run(TaskCallback callback) {
+ IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).envelope;
+ if (envelope == envelope0) {
+ // processing the first message blocks until the second one has been processed
+ try {
+ latch.await();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ } else {
+ // the second envelope completes first
+ assertEquals(0, task0.completed.get()); // NOTE: runs on callbackExecutor; a failure here is only captured by the discarded Future from submit()
+ latch.countDown();
+ }
+ }
+ };
+
+ AsyncRunLoop runLoop = createRunLoop();
+ when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null);
+ runLoop.run();
+ callbackExecutor.shutdown();
+ callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+ assertEquals(2, task0.processed);
+ assertEquals(2, task0.completed.get());
+ assertEquals(1, task1.processed);
+ assertEquals(1, task1.completed.get());
+ assertEquals(3L, containerMetrics.envelopes().getCount());
+ assertEquals(3L, containerMetrics.processes().getCount());
+ }
+
+ @Test
+ public void testWindow() throws Exception {
+ windowMs = 1;
+
+ AsyncRunLoop runLoop = createRunLoop();
+ when(consumerMultiplexer.choose(false)).thenReturn(null);
+ runLoop.run();
+ callbackExecutor.shutdown();
+ callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+ assertEquals(4, task1.windowCount);
+ }
+
+ @Test
+ public void testCommitSingleTask() throws Exception {
+ commitRequest = TaskCoordinator.RequestScope.CURRENT_TASK;
+
+ AsyncRunLoop runLoop = createRunLoop();
+ // a null message in between makes sure task0 finishes processing and invokes the commit
+ when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(null).thenReturn(envelope1).thenReturn(null);
+ runLoop.run();
+ callbackExecutor.shutdown();
+ callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+ verify(offsetManager).checkpoint(taskName0);
+ verify(offsetManager, never()).checkpoint(taskName1);
+ }
+
+ @Test
+ public void testCommitAllTasks() throws Exception {
+ commitRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
+
+ AsyncRunLoop runLoop = createRunLoop();
+ // a null message in between makes sure task0 finishes processing and invokes the commit
+ when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(null).thenReturn(envelope1).thenReturn(null);
+ runLoop.run();
+ callbackExecutor.shutdown();
+ callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+ verify(offsetManager).checkpoint(taskName0);
+ verify(offsetManager).checkpoint(taskName1);
+ }
+
+ @Test
+ public void testShutdownOnConsensus() throws Exception {
+ shutdownRequest = TaskCoordinator.RequestScope.CURRENT_TASK;
+
+ tasks = new HashMap<>();
+ task0 = new TestTask(true, true, true);
+ task1 = new TestTask(true, false, true);
+ t0 = createTaskInstance(task0, taskName0, ssp0);
+ t1 = createTaskInstance(task1, taskName1, ssp1);
+ tasks.put(taskName0, t0);
+ tasks.put(taskName1, t1);
+
+ AsyncRunLoop runLoop = createRunLoop();
+ // consensus is reached after envelope1 is processed.
+ when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+ runLoop.run();
+ callbackExecutor.shutdown();
+ callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+ assertEquals(1, task0.processed);
+ assertEquals(1, task0.completed.get());
+ assertEquals(1, task1.processed);
+ assertEquals(1, task1.completed.get());
+ assertEquals(2L, containerMetrics.envelopes().getCount());
+ assertEquals(2L, containerMetrics.processes().getCount());
+ }
+}
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java
new file mode 100644
index 0000000..99e1e18
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import org.apache.samza.config.Config;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+
+ public class TestAsyncStreamAdapter {
+   TestStreamTask task;
+   AsyncStreamTaskAdapter taskAdaptor;
+   // Exception that TestStreamTask.process() throws; null means process() succeeds.
+   // Volatile because process() may run on one of the adapter's executor threads.
+   volatile Exception e;
+   IncomingMessageEnvelope envelope;
+
+   /**
+    * Records which outcome a callback reported. The adapter may invoke the callback from an
+    * executor thread, so the flags are volatile and awaitOutcome() lets the test block until
+    * an outcome has been delivered instead of sleeping for a fixed time.
+    */
+   class TestCallbackListener implements TaskCallbackListener {
+     volatile boolean callbackComplete = false;
+     volatile boolean callbackFailure = false;
+     private final java.util.concurrent.CountDownLatch outcome = new java.util.concurrent.CountDownLatch(1);
+
+     @Override
+     public void onComplete(TaskCallback callback) {
+       callbackComplete = true;
+       outcome.countDown();
+     }
+
+     @Override
+     public void onFailure(TaskCallback callback, Throwable t) {
+       callbackFailure = true;
+       outcome.countDown();
+     }
+
+     /** Blocks until onComplete or onFailure has fired; returns false if that takes over a second. */
+     boolean awaitOutcome() throws InterruptedException {
+       return outcome.await(1, TimeUnit.SECONDS);
+     }
+   }
+
+   /** StreamTask that records which lifecycle methods the adapter forwarded to it. */
+   class TestStreamTask implements StreamTask, InitableTask, ClosableTask, WindowableTask {
+     // Volatile because process() may be invoked on an executor thread.
+     volatile boolean inited = false;
+     volatile boolean closed = false;
+     volatile boolean processed = false;
+     volatile boolean windowed = false;
+
+     @Override
+     public void close() throws Exception {
+       closed = true;
+     }
+
+     @Override
+     public void init(Config config, TaskContext context) throws Exception {
+       inited = true;
+     }
+
+     @Override
+     public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+       processed = true;
+       Exception toThrow = e;
+       if (toThrow != null) {
+         throw toThrow;
+       }
+     }
+
+     @Override
+     public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+       windowed = true;
+     }
+   }
+
+   @Before
+   public void setup() {
+     task = new TestStreamTask();
+     e = null;
+     envelope = mock(IncomingMessageEnvelope.class);
+   }
+
+   @Test
+   public void testAdapterWithoutThreadPool() throws Exception {
+     taskAdaptor = new AsyncStreamTaskAdapter(task, null);
+
+     taskAdaptor.init(null, null);
+     assertTrue(task.inited);
+
+     TestCallbackListener successListener = new TestCallbackListener();
+     TaskCallback successCallback = new TaskCallbackImpl(successListener, null, envelope, null, 0L);
+     taskAdaptor.processAsync(null, null, null, successCallback);
+     assertTrue(task.processed);
+     assertTrue(successListener.callbackComplete);
+
+     // Use a fresh callback: re-using the completed one would make TaskCallbackImpl report an
+     // IllegalStateException through onFailure, so the failure assertion would pass even if
+     // the adapter never propagated the task's exception.
+     e = new Exception("dummy exception");
+     TestCallbackListener failureListener = new TestCallbackListener();
+     TaskCallback failureCallback = new TaskCallbackImpl(failureListener, null, envelope, null, 1L);
+     taskAdaptor.processAsync(null, null, null, failureCallback);
+     assertTrue(failureListener.callbackFailure);
+     org.junit.Assert.assertFalse(failureListener.callbackComplete);
+
+     taskAdaptor.window(null, null);
+     assertTrue(task.windowed);
+
+     taskAdaptor.close();
+     assertTrue(task.closed);
+   }
+
+   @Test
+   public void testAdapterWithThreadPool() throws Exception {
+     ExecutorService executor = Executors.newFixedThreadPool(2);
+     try {
+       taskAdaptor = new AsyncStreamTaskAdapter(task, executor);
+
+       TestCallbackListener listener1 = new TestCallbackListener();
+       TestCallbackListener listener2 = new TestCallbackListener();
+       taskAdaptor.processAsync(null, null, null, new TaskCallbackImpl(listener1, null, envelope, null, 0L));
+       taskAdaptor.processAsync(null, null, null, new TaskCallbackImpl(listener2, null, envelope, null, 1L));
+
+       // The executor is never shut down here, so awaitTermination() would only act as a fixed
+       // sleep; wait on each callback's outcome instead.
+       assertTrue(listener1.awaitOutcome());
+       assertTrue(listener2.awaitOutcome());
+       assertTrue(listener1.callbackComplete);
+       assertTrue(listener2.callbackComplete);
+
+       // Fresh callbacks again, so a failure is only observed if the task's exception is propagated.
+       e = new Exception("dummy exception");
+       TestCallbackListener listener3 = new TestCallbackListener();
+       TestCallbackListener listener4 = new TestCallbackListener();
+       taskAdaptor.processAsync(null, null, null, new TaskCallbackImpl(listener3, null, envelope, null, 2L));
+       taskAdaptor.processAsync(null, null, null, new TaskCallbackImpl(listener4, null, envelope, null, 3L));
+
+       assertTrue(listener3.awaitOutcome());
+       assertTrue(listener4.awaitOutcome());
+       assertTrue(listener3.callbackFailure);
+       assertTrue(listener4.callbackFailure);
+       org.junit.Assert.assertFalse(listener3.callbackComplete);
+       org.junit.Assert.assertFalse(listener4.callbackComplete);
+     } finally {
+       executor.shutdownNow();
+     }
+   }
+ }
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java b/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java
new file mode 100644
index 0000000..d9c68d7
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.HashSet;
+import java.util.Set;
+import org.apache.samza.container.TaskName;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+ public class TestCoordinatorRequests {
+   CoordinatorRequests coordinatorRequests;
+   TaskName taskA = new TaskName("a");
+   TaskName taskB = new TaskName("b");
+   TaskName taskC = new TaskName("c");
+
+   /** Builds a CoordinatorRequests tracking the three tasks a, b and c. */
+   @Before
+   public void setup() {
+     Set<TaskName> containerTasks = new HashSet<>();
+     for (TaskName name : new TaskName[] {taskA, taskB, taskC}) {
+       containerTasks.add(name);
+     }
+     coordinatorRequests = new CoordinatorRequests(containerTasks);
+   }
+
+   /** Sends a shutdown request of the given scope on behalf of {@code task} to coordinatorRequests. */
+   private void requestShutdown(TaskName task, TaskCoordinator.RequestScope scope) {
+     ReadableCoordinator coordinator = new ReadableCoordinator(task);
+     coordinator.shutdown(scope);
+     coordinatorRequests.update(coordinator);
+   }
+
+   @Test
+   public void testUpdateCommit() {
+     ReadableCoordinator fromA = new ReadableCoordinator(taskA);
+     fromA.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+     coordinatorRequests.update(fromA);
+     assertTrue(coordinatorRequests.commitRequests().contains(taskA));
+
+     ReadableCoordinator fromC = new ReadableCoordinator(taskC);
+     fromC.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+     coordinatorRequests.update(fromC);
+     assertTrue(coordinatorRequests.commitRequests().contains(taskC));
+     assertFalse(coordinatorRequests.commitRequests().contains(taskB));
+     assertTrue(coordinatorRequests.commitRequests().size() == 2);
+
+     // Widen task c's request to the whole container: task b must now be included as well.
+     fromC.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+     coordinatorRequests.update(fromC);
+     assertTrue(coordinatorRequests.commitRequests().contains(taskB));
+     assertTrue(coordinatorRequests.commitRequests().size() == 3);
+   }
+
+   @Test
+   public void testUpdateShutdownOnConsensus() {
+     // CURRENT_TASK shutdown only takes effect once every task has asked for it.
+     requestShutdown(taskA, TaskCoordinator.RequestScope.CURRENT_TASK);
+     assertFalse(coordinatorRequests.shouldShutdownNow());
+
+     requestShutdown(taskB, TaskCoordinator.RequestScope.CURRENT_TASK);
+     assertFalse(coordinatorRequests.shouldShutdownNow());
+
+     requestShutdown(taskC, TaskCoordinator.RequestScope.CURRENT_TASK);
+     assertTrue(coordinatorRequests.shouldShutdownNow());
+   }
+
+   @Test
+   public void testUpdateShutdownNow() {
+     // A single ALL_TASKS_IN_CONTAINER request is enough.
+     requestShutdown(taskA, TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+     assertTrue(coordinatorRequests.shouldShutdownNow());
+   }
+ }
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java
new file mode 100644
index 0000000..f1dbf35
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+
+ public class TestTaskCallbackImpl {
+
+   TaskCallbackListener listener = null;
+   AtomicInteger completeCount;
+   AtomicInteger failureCount;
+   TaskCallback callback = null;
+   // Last Throwable passed to onFailure. Volatile because testMultithreadedCallbacks writes it
+   // from an executor thread and reads it on the test thread.
+   volatile Throwable throwable = null;
+
+   /** Creates a fresh callback whose listener counts completions and failures. */
+   @Before
+   public void setup() {
+     completeCount = new AtomicInteger(0);
+     failureCount = new AtomicInteger(0);
+     throwable = null;
+
+     listener = new TaskCallbackListener() {
+
+       @Override
+       public void onComplete(TaskCallback callback) {
+         completeCount.incrementAndGet();
+       }
+
+       @Override
+       public void onFailure(TaskCallback callback, Throwable t) {
+         throwable = t;
+         failureCount.incrementAndGet();
+       }
+     };
+
+     callback = new TaskCallbackImpl(listener, null, mock(IncomingMessageEnvelope.class), null, 0);
+   }
+
+   @Test
+   public void testComplete() {
+     callback.complete();
+     assertEquals(1L, completeCount.get());
+     assertEquals(0L, failureCount.get());
+   }
+
+   @Test
+   public void testFailure() {
+     callback.failure(new Exception("dummy exception"));
+     assertEquals(0L, completeCount.get());
+     assertEquals(1L, failureCount.get());
+   }
+
+   /** A second complete() is reported to the listener as an IllegalStateException failure. */
+   @Test
+   public void testCallbackMultipleComplete() {
+     callback.complete();
+     assertEquals(1L, completeCount.get());
+
+     callback.complete();
+     assertEquals(1L, failureCount.get());
+     assertTrue(throwable instanceof IllegalStateException);
+   }
+
+   /** failure() after complete() is reported to the listener as an IllegalStateException. */
+   @Test
+   public void testCallbackFailureAfterComplete() {
+     callback.complete();
+     assertEquals(1L, completeCount.get());
+
+     callback.failure(new Exception("dummy exception"));
+     assertEquals(1L, failureCount.get());
+     assertTrue(throwable instanceof IllegalStateException);
+   }
+
+
+   /**
+    * Two threads, released together by a barrier, call complete() on the same callback.
+    * Exactly one completion is accepted; the other is reported as an IllegalStateException.
+    */
+   @Test
+   public void testMultithreadedCallbacks() throws Exception {
+     final CyclicBarrier barrier = new CyclicBarrier(2);
+     ExecutorService executor = Executors.newFixedThreadPool(2);
+
+     for (int i = 0; i < 2; i++) {
+       executor.submit(new Runnable() {
+         @Override
+         public void run() {
+           try {
+             barrier.await();
+             callback.complete();
+           } catch (Exception e) {
+             e.printStackTrace();
+           }
+         }
+       });
+     }
+     // awaitTermination() only returns early after shutdown(); without it the call is a fixed
+     // one-second sleep that never reports whether the workers actually finished.
+     executor.shutdown();
+     assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS));
+     assertEquals(1L, completeCount.get());
+     assertEquals(1L, failureCount.get());
+     assertTrue(throwable instanceof IllegalStateException);
+   }
+ }
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
new file mode 100644
index 0000000..d7110f3
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import org.apache.samza.Partition;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+
+ public class TestTaskCallbackManager {
+   TaskCallbackManager callbackManager = null;
+   TaskCallbackListener listener = null;
+
+   @Before
+   public void setup() {
+     // Listener that ignores every outcome; these tests only inspect the manager's return values.
+     listener = new TaskCallbackListener() {
+       @Override
+       public void onComplete(TaskCallback callback) {
+       }
+
+       @Override
+       public void onFailure(TaskCallback callback, Throwable t) {
+       }
+     };
+     TaskInstanceMetrics metrics = new TaskInstanceMetrics("Partition 0", new MetricsRegistryMap());
+     callbackManager = new TaskCallbackManager(listener, metrics, null, -1);
+   }
+
+   /** Creates a callback for the message at {@code offset} in {@code ssp}, tagged with {@code seqNum}. */
+   private TaskCallbackImpl callbackFor(TaskName taskName, SystemStreamPartition ssp, String offset,
+       ReadableCoordinator coordinator, int seqNum) {
+     IncomingMessageEnvelope envelope = new IncomingMessageEnvelope(ssp, offset, null, null);
+     return new TaskCallbackImpl(listener, taskName, envelope, coordinator, seqNum);
+   }
+
+   /** Asserts that {@code callback} carries {@code seqNum} and the message at {@code offset} in {@code ssp}. */
+   private static void assertCallback(TaskCallbackImpl callback, int seqNum, SystemStreamPartition ssp, String offset) {
+     assertTrue(callback.matchSeqNum(seqNum));
+     assertEquals(ssp, callback.envelope.getSystemStreamPartition());
+     assertEquals(offset, callback.envelope.getOffset());
+   }
+
+   @Test
+   public void testCreateCallback() {
+     TaskName taskName = new TaskName("Partition 0");
+     // Sequence numbers are handed out consecutively starting from zero.
+     assertTrue(callbackManager.createCallback(taskName, null, null).matchSeqNum(0));
+     assertTrue(callbackManager.createCallback(taskName, null, null).matchSeqNum(1));
+   }
+
+   @Test
+   public void testUpdateCallbackInOrder() {
+     TaskName taskName = new TaskName("Partition 0");
+     SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+     ReadableCoordinator coordinator = new ReadableCoordinator(taskName);
+
+     TaskCallbackImpl toCommit = callbackManager.updateCallback(callbackFor(taskName, ssp, "0", coordinator, 0), true);
+     assertCallback(toCommit, 0, ssp, "0");
+
+     toCommit = callbackManager.updateCallback(callbackFor(taskName, ssp, "1", coordinator, 1), true);
+     assertCallback(toCommit, 1, ssp, "1");
+   }
+
+   @Test
+   public void testUpdateCallbackOutofOrder() {
+     TaskName taskName = new TaskName("Partition 0");
+     SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+     ReadableCoordinator coordinator = new ReadableCoordinator(taskName);
+
+     // Callbacks 2 and 1 finish before 0, so nothing is committable until 0 arrives.
+     assertNull(callbackManager.updateCallback(callbackFor(taskName, ssp, "2", coordinator, 2), true));
+     assertNull(callbackManager.updateCallback(callbackFor(taskName, ssp, "1", coordinator, 1), true));
+
+     TaskCallbackImpl toCommit = callbackManager.updateCallback(callbackFor(taskName, ssp, "0", coordinator, 0), true);
+     assertCallback(toCommit, 2, ssp, "2");
+   }
+
+   @Test
+   public void testUpdateCallbackWithCoordinatorRequests() {
+     TaskName taskName = new TaskName("Partition 0");
+     SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+
+     // Out-of-order completion: 2 (shutdown request), then 1 (commit request), then 0.
+     ReadableCoordinator coordinator2 = new ReadableCoordinator(taskName);
+     coordinator2.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+     assertNull(callbackManager.updateCallback(callbackFor(taskName, ssp, "2", coordinator2, 2), true));
+
+     ReadableCoordinator coordinator1 = new ReadableCoordinator(taskName);
+     coordinator1.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+     assertNull(callbackManager.updateCallback(callbackFor(taskName, ssp, "1", coordinator1, 1), true));
+
+     ReadableCoordinator coordinator0 = new ReadableCoordinator(taskName);
+     TaskCallbackImpl toCommit = callbackManager.updateCallback(callbackFor(taskName, ssp, "0", coordinator0, 0), true);
+     assertCallback(toCommit, 1, ssp, "1");
+     assertTrue(toCommit.coordinator.requestedShutdownNow());
+   }
+
+ }
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
index e280daa..aa1a8d6 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
@@ -20,22 +20,26 @@
package org.apache.samza.container
-import org.apache.samza.metrics.{Timer, SlidingTimeWindowReservoir, MetricsRegistryMap}
+import org.apache.samza.Partition
+import org.apache.samza.metrics.MetricsRegistryMap
+import org.apache.samza.metrics.SlidingTimeWindowReservoir
+import org.apache.samza.metrics.Timer
+import org.apache.samza.system.IncomingMessageEnvelope
+import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.task.TaskCoordinator.RequestScope
+import org.apache.samza.task.ReadableCoordinator
+import org.apache.samza.task.StreamTask
import org.apache.samza.util.Clock
-import org.junit.Test
import org.junit.Assert._
+import org.junit.Test
import org.mockito.Matchers
import org.mockito.Mockito._
-import org.mockito.internal.util.reflection.Whitebox
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.junit.AssertionsForJUnit
-import org.scalatest.{Matchers => ScalaTestMatchers}
import org.scalatest.mock.MockitoSugar
-import org.apache.samza.Partition
-import org.apache.samza.system.{ IncomingMessageEnvelope, SystemConsumers, SystemStreamPartition }
-import org.apache.samza.task.ReadableCoordinator
-import org.apache.samza.task.TaskCoordinator.RequestScope
+import org.scalatest.{Matchers => ScalaTestMatchers}
class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMatchers {
class StopRunLoop extends RuntimeException
@@ -49,12 +53,12 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
val envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0")
val envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1")
- def getMockTaskInstances: Map[TaskName, TaskInstance] = {
- val ti0 = mock[TaskInstance]
+ def getMockTaskInstances: Map[TaskName, TaskInstance[StreamTask]] = {
+ val ti0 = mock[TaskInstance[StreamTask]]
when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
when(ti0.taskName).thenReturn(taskName0)
- val ti1 = mock[TaskInstance]
+ val ti1 = mock[TaskInstance[StreamTask]]
when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
when(ti1.taskName).thenReturn(taskName1)
@@ -67,10 +71,10 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
val consumers = mock[SystemConsumers]
val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics)
- when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
+ when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
intercept[StopRunLoop] { runLoop.run }
- verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject)
- verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject)
+ verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject, anyObject)
+ verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject, anyObject)
runLoop.metrics.envelopes.getCount should equal(2L)
runLoop.metrics.nullEnvelopes.getCount should equal(0L)
}
@@ -80,7 +84,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
val consumers = mock[SystemConsumers]
val map = getMockTaskInstances - taskName1 // This test only needs p0
val runLoop = new RunLoop(map, consumers, new SamzaContainerMetrics)
- when(consumers.choose).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop)
+ when(consumers.choose()).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop)
intercept[StopRunLoop] { runLoop.run }
runLoop.metrics.envelopes.getCount should equal(0L)
runLoop.metrics.nullEnvelopes.getCount should equal(2L)
@@ -90,7 +94,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
def testWindowAndCommitAreCalledRegularly {
var now = 1400000000000L
val consumers = mock[SystemConsumers]
- when(consumers.choose).thenReturn(envelope0)
+ when(consumers.choose()).thenReturn(envelope0)
val runLoop = new RunLoop(
taskInstances = getMockTaskInstances,
@@ -118,7 +122,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
val consumers = mock[SystemConsumers]
val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
- when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
+ when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.CURRENT_TASK))
intercept[StopRunLoop] { runLoop.run }
@@ -132,7 +136,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
val consumers = mock[SystemConsumers]
val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
- when(consumers.choose).thenReturn(envelope0).thenThrow(new StopRunLoop)
+ when(consumers.choose()).thenReturn(envelope0).thenThrow(new StopRunLoop)
stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.ALL_TASKS_IN_CONTAINER))
intercept[StopRunLoop] { runLoop.run }
@@ -146,13 +150,13 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
val consumers = mock[SystemConsumers]
val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
- when(consumers.choose).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1)
+ when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1)
stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
stubProcess(taskInstances(taskName1), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
runLoop.run
- verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject)
- verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject)
+ verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject, anyObject)
+ verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject, anyObject)
}
@Test
@@ -161,19 +165,19 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
val consumers = mock[SystemConsumers]
val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
- when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1)
+ when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1)
stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.ALL_TASKS_IN_CONTAINER))
runLoop.run
- verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject)
- verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject)
+ verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject, anyObject)
+ verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject, anyObject)
}
def anyObject[T] = Matchers.anyObject.asInstanceOf[T]
// Stub out TaskInstance.process. Mockito really doesn't make this easy. :(
- def stubProcess(taskInstance: TaskInstance, process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) {
- when(taskInstance.process(anyObject, anyObject)).thenAnswer(new Answer[Unit]() {
+ def stubProcess(taskInstance: TaskInstance[StreamTask], process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) {
+ when(taskInstance.process(anyObject, anyObject, anyObject)).thenAnswer(new Answer[Unit]() {
override def answer(invocation: InvocationOnMock) {
val envelope = invocation.getArguments()(0).asInstanceOf[IncomingMessageEnvelope]
val coordinator = invocation.getArguments()(1).asInstanceOf[ReadableCoordinator]
@@ -186,7 +190,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
def testUpdateTimerCorrectly {
var now = 0L
val consumers = mock[SystemConsumers]
- when(consumers.choose).thenReturn(envelope0)
+ when(consumers.choose()).thenReturn(envelope0)
val clock = new Clock {
var c = 0L
def currentTimeMillis: Long = {
@@ -263,9 +267,9 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
@Test
def testGetSystemStreamPartitionToTaskInstancesMapping {
- val ti0 = mock[TaskInstance]
- val ti1 = mock[TaskInstance]
- val ti2 = mock[TaskInstance]
+ val ti0 = mock[TaskInstance[StreamTask]]
+ val ti1 = mock[TaskInstance[StreamTask]]
+ val ti2 = mock[TaskInstance[StreamTask]]
when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
when(ti2.systemStreamPartitions).thenReturn(Set(ssp1))
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index 1358fdd..cff6b96 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -180,7 +180,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
new SerdeManager)
val collector = new TaskInstanceCollector(producerMultiplexer)
val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName))
- val taskInstance: TaskInstance = new TaskInstance(
+ val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
task,
taskName,
config,
@@ -261,7 +261,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
}
})
- val taskInstance: TaskInstance = new TaskInstance(
+ val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
task,
taskName,
config,
@@ -314,4 +314,4 @@ class MockJobServlet(exceptionLimit: Int, jobModelRef: AtomicReference[JobModel]
jobModel
}
}
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 5457f0e..3c83529 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -71,7 +71,7 @@ class TestTaskInstance {
val taskName = new TaskName("taskName")
val collector = new TaskInstanceCollector(producerMultiplexer)
val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName))
- val taskInstance: TaskInstance = new TaskInstance(
+ val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
task,
taskName,
config,
@@ -169,7 +169,7 @@ class TestTaskInstance {
val registry = new MetricsRegistryMap
val taskMetrics = new TaskInstanceMetrics(registry = registry)
- val taskInstance: TaskInstance = new TaskInstance(
+ val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
task,
taskName,
config,
@@ -226,7 +226,7 @@ class TestTaskInstance {
val registry = new MetricsRegistryMap
val taskMetrics = new TaskInstanceMetrics(registry = registry)
- val taskInstance: TaskInstance = new TaskInstance(
+ val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
task,
taskName,
config,
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala b/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
index 09da62e..db2249b 100644
--- a/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
+++ b/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
@@ -54,14 +54,14 @@ class TestSystemConsumers {
consumer.setResponseSizes(numEnvelopes)
// Choose to trigger a refresh with data.
- assertNull(consumers.choose)
+ assertNull(consumers.choose())
// 2: First on start, second on choose.
assertEquals(2, consumer.polls)
assertEquals(2, consumer.lastPoll.size)
assertTrue(consumer.lastPoll.contains(systemStreamPartition0))
assertTrue(consumer.lastPoll.contains(systemStreamPartition1))
- assertEquals(envelope, consumers.choose)
- assertEquals(envelope, consumers.choose)
+ assertEquals(envelope, consumers.choose())
+ assertEquals(envelope, consumers.choose())
// We aren't polling because we're getting non-null envelopes.
assertEquals(2, consumer.polls)
@@ -69,7 +69,7 @@ class TestSystemConsumers {
// messages.
now = SystemConsumers.DEFAULT_POLL_INTERVAL_MS
- assertEquals(envelope, consumers.choose)
+ assertEquals(envelope, consumers.choose())
// We polled even though there are still 997 messages in the unprocessed
// message buffer.
@@ -82,11 +82,11 @@ class TestSystemConsumers {
// Now drain all messages for SSP0. There should be exactly 997 messages,
// since we have chosen 3 already, and we started with 1000.
(0 until (numEnvelopes - 3)).foreach { i =>
- assertEquals(envelope, consumers.choose)
+ assertEquals(envelope, consumers.choose())
}
// Nothing left. Should trigger a poll here.
- assertNull(consumers.choose)
+ assertNull(consumers.choose())
assertEquals(4, consumer.polls)
assertEquals(2, consumer.lastPoll.size)
@@ -117,31 +117,31 @@ class TestSystemConsumers {
consumer.setResponseSizes(1)
// Choose to trigger a refresh with data.
- assertNull(consumers.choose)
+ assertNull(consumers.choose())
// Choose should have triggered a second poll, since no messages are available.
assertEquals(2, consumer.polls)
// Choose a few times. This time there is no data.
- assertEquals(envelope, consumers.choose)
- assertNull(consumers.choose)
- assertNull(consumers.choose)
+ assertEquals(envelope, consumers.choose())
+ assertNull(consumers.choose())
+ assertNull(consumers.choose())
// Return more than one message this time.
consumer.setResponseSizes(2)
// Choose to trigger a refresh with data.
- assertNull(consumers.choose)
+ assertNull(consumers.choose())
// Increase clock interval.
now = SystemConsumers.DEFAULT_POLL_INTERVAL_MS
// We get two messages now.
- assertEquals(envelope, consumers.choose)
+ assertEquals(envelope, consumers.choose())
// Should not poll even though clock interval increases past interval threshold.
assertEquals(2, consumer.polls)
- assertEquals(envelope, consumers.choose)
- assertNull(consumers.choose)
+ assertEquals(envelope, consumers.choose())
+ assertNull(consumers.choose())
}
@Test
@@ -238,7 +238,7 @@ class TestSystemConsumers {
var caughtRightException = false
try {
- consumers.choose
+ consumers.choose()
} catch {
case e: SystemConsumersException => caughtRightException = true
case _: Throwable => caughtRightException = false
@@ -256,13 +256,13 @@ class TestSystemConsumers {
var notThrowException = true;
try {
- consumers2.choose
+ consumers2.choose()
} catch {
case e: Throwable => notThrowException = false
}
assertTrue("it should not throw any exception", notThrowException)
- var msgEnvelope = Some(consumers2.choose)
+ var msgEnvelope = Some(consumers2.choose())
assertTrue("Consumer did not succeed in receiving the second message after Serde exception in choose", msgEnvelope.get != null)
consumers2.stop
@@ -279,7 +279,7 @@ class TestSystemConsumers {
assertTrue("SystemConsumer start should not throw any Serde exception", notThrowException)
msgEnvelope = null
- msgEnvelope = Some(consumers2.choose)
+ msgEnvelope = Some(consumers2.choose())
assertTrue("Consumer did not succeed in receiving the second message after Serde exception in poll", msgEnvelope.get != null)
consumers2.stop
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
----------------------------------------------------------------------
diff --git a/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala b/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
index 1f4b5c4..24bc8b5 100644
--- a/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
+++ b/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
@@ -36,6 +36,7 @@ class HdfsSystemProducer(
val clock: () => Long = () => System.currentTimeMillis) extends SystemProducer with Logging with TimerUtils {
val dfs = FileSystem.get(new Configuration(true))
val writers: MMap[String, HdfsWriter[_]] = MMap.empty[String, HdfsWriter[_]]
+ private val lock = new Object //synchronization lock for thread safe access
def start(): Unit = {
info("entering HdfsSystemProducer.start() call for system: " + systemName + ", client: " + clientId)
@@ -43,52 +44,65 @@ class HdfsSystemProducer(
def stop(): Unit = {
info("entering HdfsSystemProducer.stop() for system: " + systemName + ", client: " + clientId)
- writers.values.map { _.close }
- dfs.close
+
+ lock.synchronized {
+ writers.values.map(_.close)
+ dfs.close
+ }
}
def register(source: String): Unit = {
info("entering HdfsSystemProducer.register(" + source + ") " +
"call for system: " + systemName + ", client: " + clientId)
- writers += (source -> HdfsWriter.getInstance(dfs, systemName, config))
+
+ lock.synchronized {
+ writers += (source -> HdfsWriter.getInstance(dfs, systemName, config))
+ }
}
def flush(source: String): Unit = {
debug("entering HdfsSystemProducer.flush(" + source + ") " +
"call for system: " + systemName + ", client: " + clientId)
- try {
- metrics.flushes.inc
- updateTimer(metrics.flushMs) { writers.get(source).head.flush }
- metrics.flushSuccess.inc
- } catch {
- case e: Exception => {
- metrics.flushFailed.inc
- warn("Exception thrown while client " + clientId + " flushed HDFS out stream, msg: " + e.getMessage)
- debug("Detailed message from exception thrown by client " + clientId + " in HDFS flush: ", e)
- writers.get(source).head.close
- throw e
+
+ metrics.flushes.inc
+ lock.synchronized {
+ try {
+ updateTimer(metrics.flushMs) {
+ writers.get(source).head.flush
+ }
+ } catch {
+ case e: Exception => {
+ metrics.flushFailed.inc
+ warn("Exception thrown while client " + clientId + " flushed HDFS out stream, msg: " + e.getMessage)
+ debug("Detailed message from exception thrown by client " + clientId + " in HDFS flush: ", e)
+ writers.get(source).head.close
+ throw e
+ }
}
}
+ metrics.flushSuccess.inc
}
def send(source: String, ome: OutgoingMessageEnvelope) = {
debug("entering HdfsSystemProducer.send(source = " + source + ", envelope) " +
"call for system: " + systemName + ", client: " + clientId)
+
metrics.sends.inc
- try {
- updateTimer(metrics.sendMs) {
- writers.get(source).head.write(ome)
- }
- metrics.sendSuccess.inc
- } catch {
- case e: Exception => {
- metrics.sendFailed.inc
- warn("Exception thrown while client " + clientId + " wrote to HDFS, msg: " + e.getMessage)
- debug("Detailed message from exception thrown by client " + clientId + " in HDFS write: ", e)
- writers.get(source).head.close
- throw e
+ lock.synchronized {
+ try {
+ updateTimer(metrics.sendMs) {
+ writers.get(source).head.write(ome)
+ }
+ } catch {
+ case e: Exception => {
+ metrics.sendFailed.inc
+ warn("Exception thrown while client " + clientId + " wrote to HDFS, msg: " + e.getMessage)
+ debug("Detailed message from exception thrown by client " + clientId + " in HDFS write: ", e)
+ writers.get(source).head.close
+ throw e
+ }
}
}
+ metrics.sendSuccess.inc
}
-
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
----------------------------------------------------------------------
diff --git a/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala b/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
index 5e8cc65..5d2641a 100644
--- a/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
+++ b/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
@@ -140,6 +140,7 @@ class KafkaCheckpointMigration extends MigrationPlan with Logging {
def migrationCompletionMark(coordinatorSystemProducer: CoordinatorStreamSystemProducer) = {
info("Marking completion of migration %s" format migrationKey)
val message = new SetMigrationMetaMessage(source, migrationKey, migrationVal)
+ coordinatorSystemProducer.register(source)
coordinatorSystemProducer.start()
coordinatorSystemProducer.send(message)
coordinatorSystemProducer.stop()