You are viewing a plain text version of this content. The canonical link for it is "here" (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@openwhisk.apache.org by ni...@apache.org on 2021/05/28 01:14:30 UTC

[openwhisk] branch master updated: [New Scheduler] Add container message consumer (#5111)

This is an automated email from the ASF dual-hosted git repository.

ningyougang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/openwhisk.git


The following commit(s) were added to refs/heads/master by this push:
     new b818c3b  [New Scheduler] Add container message consumer (#5111)
b818c3b is described below

commit b818c3b3e8bd3fa9ac7742d1b8c051ec09b76ae2
Author: Seonghyun Oh <oh...@navercorp.com>
AuthorDate: Fri May 28 10:14:11 2021 +0900

    [New Scheduler] Add container message consumer (#5111)
    
    * Add container message consumer
    
    * Reformat code
    
    * Fix test case error
    
    Co-authored-by: ning.yougang <ni...@navercorp.com>
---
 .../apache/openwhisk/core/connector/Message.scala  | 135 ++++++---
 .../org/apache/openwhisk/core/entity/Size.scala    |  41 ++-
 .../core/invoker/ContainerMessageConsumer.scala    | 132 +++++++++
 .../test/ContainerMessageConsumerTests.scala       | 328 +++++++++++++++++++++
 4 files changed, 586 insertions(+), 50 deletions(-)

diff --git a/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala b/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala
index 88f4f11..ba05c17 100644
--- a/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala
+++ b/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala
@@ -60,7 +60,7 @@ case class ActivationMessage(override val transid: TransactionId,
                              lockedArgs: Map[String, String] = Map.empty,
                              cause: Option[ActivationId] = None,
                              traceContext: Option[Map[String, String]] = None)
-    extends Message {
+  extends Message {
 
   override def serialize = ActivationMessage.serdes.write(this).compactPrint
 
@@ -78,6 +78,7 @@ case class ActivationMessage(override val transid: TransactionId,
  */
 abstract class AcknowledegmentMessage(private val tid: TransactionId) extends Message {
   override val transid: TransactionId = tid
+
   override def serialize: String = AcknowledegmentMessage.serdes.write(this).compactPrint
 
   /** Pithy descriptor for logging. */
@@ -115,17 +116,23 @@ abstract class AcknowledegmentMessage(private val tid: TransactionId) extends Me
  * The constructor is private so that callers must use the more restrictive constructors which ensure the respose is always
  * Right when this message is created.
  */
-case class CombinedCompletionAndResultMessage private (override val transid: TransactionId,
-                                                       response: Either[ActivationId, WhiskActivation],
-                                                       override val isSystemError: Option[Boolean],
-                                                       instance: InstanceId)
-    extends AcknowledegmentMessage(transid) {
+case class CombinedCompletionAndResultMessage private(override val transid: TransactionId,
+                                                      response: Either[ActivationId, WhiskActivation],
+                                                      override val isSystemError: Option[Boolean],
+                                                      instance: InstanceId)
+  extends AcknowledegmentMessage(transid) {
   override def messageType = "combined"
+
   override def result = Some(response)
+
   override def isSlotFree = Some(instance)
+
   override def activationId = response.fold(identity, _.activationId)
+
   override def toJson = CombinedCompletionAndResultMessage.serdes.write(this)
+
   override def shrink = copy(response = response.flatMap(a => Left(a.activationId)))
+
   override def toString = activationId.asString
 }
 
@@ -135,16 +142,21 @@ case class CombinedCompletionAndResultMessage private (override val transid: Tra
  * phase notification to the load balancer where an invoker first sends a `ResultMessage` and later sends the
  * `CompletionMessage`.
  */
-case class CompletionMessage private (override val transid: TransactionId,
-                                      override val activationId: ActivationId,
-                                      override val isSystemError: Option[Boolean],
-                                      instance: InstanceId)
-    extends AcknowledegmentMessage(transid) {
+case class CompletionMessage private(override val transid: TransactionId,
+                                     override val activationId: ActivationId,
+                                     override val isSystemError: Option[Boolean],
+                                     instance: InstanceId)
+  extends AcknowledegmentMessage(transid) {
   override def messageType = "completion"
+
   override def result = None
+
   override def isSlotFree = Some(instance)
+
   override def toJson = CompletionMessage.serdes.write(this)
+
   override def shrink = this
+
   override def toString = activationId.asString
 }
 
@@ -156,15 +168,22 @@ case class CompletionMessage private (override val transid: TransactionId,
  * The constructor is private so that callers must use the more restrictive constructors which ensure the respose is always
  * Right when this message is created.
  */
-case class ResultMessage private (override val transid: TransactionId, response: Either[ActivationId, WhiskActivation])
-    extends AcknowledegmentMessage(transid) {
+case class ResultMessage private(override val transid: TransactionId, response: Either[ActivationId, WhiskActivation])
+  extends AcknowledegmentMessage(transid) {
   override def messageType = "result"
+
   override def result = Some(response)
+
   override def isSlotFree = None
+
   override def isSystemError = response.fold(_ => None, a => Some(a.response.isWhiskError))
+
   override def activationId = response.fold(identity, _.activationId)
+
   override def toJson = ResultMessage.serdes.write(this)
+
   override def shrink = copy(response = response.flatMap(a => Left(a.activationId)))
+
   override def toString = activationId.asString
 }
 
@@ -234,7 +253,7 @@ object AcknowledegmentMessage extends DefaultJsonProtocol {
         Left(value.convertTo[ActivationId])
 
       case _: JsObject => Right(value.convertTo[WhiskActivation])
-      case _           => deserializationError("could not read ResultMessage")
+      case _ => deserializationError("could not read ResultMessage")
     }
   }
 
@@ -265,6 +284,7 @@ case class PingMessage(instance: InvokerInstanceId) extends Message {
 
 object PingMessage extends DefaultJsonProtocol {
   def parse(msg: String) = Try(serdes.read(msg.parseJson))
+
   implicit val serdes = jsonFormat(PingMessage.apply _, "name")
 }
 
@@ -276,7 +296,7 @@ object EventMessageBody extends DefaultJsonProtocol {
 
   implicit val format = new JsonFormat[EventMessageBody] {
     def write(eventMessageBody: EventMessageBody) = eventMessageBody match {
-      case m: Metric     => m.toJson
+      case m: Metric => m.toJson
       case a: Activation => a.toJson
     }
 
@@ -301,9 +321,11 @@ case class Activation(name: String,
                       causedBy: Option[String],
                       size: Option[Int] = None,
                       userDefinedStatusCode: Option[Int] = None)
-    extends EventMessageBody {
+  extends EventMessageBody {
   val typeName = Activation.typeName
+
   override def serialize = toJson.compactPrint
+
   def entityPath: FullyQualifiedEntityName = EntityPath(name).toFullyQualifiedEntityName
 
   def toJson = Activation.activationFormat.write(this)
@@ -327,12 +349,12 @@ object Activation extends DefaultJsonProtocol {
   private implicit val durationFormat = new RootJsonFormat[Duration] {
     override def write(obj: Duration): JsValue = obj match {
       case o if o.isFinite => JsNumber(o.toMillis)
-      case _               => JsNumber.zero
+      case _ => JsNumber.zero
     }
 
     override def read(json: JsValue): Duration = json match {
       case JsNumber(n) if n <= 0 => Duration.Zero
-      case JsNumber(n)           => toDuration(n.longValue)
+      case JsNumber(n) => toDuration(n.longValue)
     }
   }
 
@@ -352,7 +374,7 @@ object Activation extends DefaultJsonProtocol {
       "size",
       "userDefinedStatusCode")
 
-  /** Get "StatusCode" from result response set by action developer **/
+  /** Get "StatusCode" from result response set by action developer */
   def userDefinedStatusCode(result: Option[JsValue]): Option[Int] = {
     val statusCode = JsHelpers
       .getFieldPath(result.get.asJsObject, ERROR_FIELD, "statusCode")
@@ -394,13 +416,17 @@ object Activation extends DefaultJsonProtocol {
 
 case class Metric(metricName: String, metricValue: Long) extends EventMessageBody {
   val typeName = "Metric"
+
   override def serialize = toJson.compactPrint
+
   def toJson = Metric.metricFormat.write(this).asJsObject
 }
 
 object Metric extends DefaultJsonProtocol {
   val typeName = "Metric"
+
   def parse(msg: String) = Try(metricFormat.read(msg.parseJson))
+
   implicit val metricFormat = jsonFormat(Metric.apply _, "metricName", "metricValue")
 }
 
@@ -411,7 +437,7 @@ case class EventMessage(source: String,
                         userId: UUID,
                         eventType: String,
                         timestamp: Long = System.currentTimeMillis())
-    extends Message {
+  extends Message {
   override def serialize = EventMessage.format.write(this).compactPrint
 }
 
@@ -434,7 +460,7 @@ case class InvokerResourceMessage(status: String,
                                   inProgressMemory: Long,
                                   tags: Seq[String],
                                   dedicatedNamespaces: Seq[String])
-    extends Message {
+  extends Message {
 
   /**
    * Serializes message to string. Must be idempotent.
@@ -444,6 +470,7 @@ case class InvokerResourceMessage(status: String,
 
 object InvokerResourceMessage extends DefaultJsonProtocol {
   def parse(msg: String): Try[InvokerResourceMessage] = Try(serdes.read(msg.parseJson))
+
   implicit val serdes =
     jsonFormat(
       InvokerResourceMessage.apply _,
@@ -462,23 +489,25 @@ object InvokerResourceMessage extends DefaultJsonProtocol {
  *
  * [
  * ...
- *    {
- *       "data": "RunningData",
- *       "fqn": "whisk.system/elasticsearch/status-alarm@0.0.2",
- *       "invocationNamespace": "style95",
- *       "status": "Running",
- *       "waitingActivation": 1
- *    },
+ *    {
+ *       "data": "RunningData",
+ *       "fqn": "whisk.system/elasticsearch/status-alarm@0.0.2",
+ *       "invocationNamespace": "style95",
+ *       "status": "Running",
+ *       "waitingActivation": 1
+ *    },
  * ...
  * ]
  */
 object StatusQuery
+
 case class StatusData(invocationNamespace: String, fqn: String, waitingActivation: Int, status: String, data: String)
-    extends Message {
+  extends Message {
 
   override def serialize: String = StatusData.serdes.write(this).compactPrint
 
 }
+
 object StatusData extends DefaultJsonProtocol {
 
   implicit val serdes =
@@ -495,9 +524,10 @@ case class ContainerCreationMessage(override val transid: TransactionId,
                                     rpcPort: Int,
                                     retryCount: Int = 0,
                                     creationId: CreationId = CreationId.generate())
-    extends ContainerMessage(transid) {
+  extends ContainerMessage(transid) {
 
   override def toJson: JsValue = ContainerCreationMessage.serdes.write(this)
+
   override def serialize: String = toJson.compactPrint
 }
 
@@ -526,8 +556,9 @@ case class ContainerDeletionMessage(override val transid: TransactionId,
                                     action: FullyQualifiedEntityName,
                                     revision: DocRevision,
                                     whiskActionMetaData: WhiskActionMetaData)
-    extends ContainerMessage(transid) {
+  extends ContainerMessage(transid) {
   override def toJson: JsValue = ContainerDeletionMessage.serdes.write(this)
+
   override def serialize: String = toJson.compactPrint
 }
 
@@ -544,6 +575,7 @@ object ContainerDeletionMessage extends DefaultJsonProtocol {
 
 abstract class ContainerMessage(private val tid: TransactionId) extends Message {
   override val transid: TransactionId = tid
+
   override def serialize: String = ContainerMessage.serdes.write(this).compactPrint
 
   /** Serializes the message to JSON. */
@@ -569,18 +601,31 @@ object ContainerMessage extends DefaultJsonProtocol {
 }
 
 sealed trait ContainerCreationError
+
 object ContainerCreationError extends Enumeration {
+
   case object NoAvailableInvokersError extends ContainerCreationError
+
   case object NoAvailableResourceInvokersError extends ContainerCreationError
+
   case object ResourceNotEnoughError extends ContainerCreationError
+
   case object WhiskError extends ContainerCreationError
+
   case object UnknownError extends ContainerCreationError
+
   case object TimeoutError extends ContainerCreationError
+
   case object ShuttingDownError extends ContainerCreationError
+
   case object NonExecutableActionError extends ContainerCreationError
+
   case object DBFetchError extends ContainerCreationError
+
   case object BlackBoxError extends ContainerCreationError
+
   case object ZeroNamespaceLimit extends ContainerCreationError
+
   case object TooManyConcurrentRequests extends ContainerCreationError
 
   val whiskErrors: Set[ContainerCreationError] =
@@ -594,26 +639,27 @@ object ContainerCreationError extends Enumeration {
       TimeoutError,
       ZeroNamespaceLimit)
 
-  def fromName(name: String) = name.toUpperCase match {
-    case "NOAVAILABLEINVOKERSERROR"         => NoAvailableInvokersError
+  private def parse(name: String) = name.toUpperCase match {
+    case "NOAVAILABLEINVOKERSERROR" => NoAvailableInvokersError
     case "NOAVAILABLERESOURCEINVOKERSERROR" => NoAvailableResourceInvokersError
-    case "RESOURCENOTENOUGHERROR"           => ResourceNotEnoughError
-    case "NONEXECUTBLEACTIONERROR"          => NonExecutableActionError
-    case "DBFETCHERROR"                     => DBFetchError
-    case "WHISKERROR"                       => WhiskError
-    case "BLACKBOXERROR"                    => BlackBoxError
-    case "TIMEOUTERROR"                     => TimeoutError
-    case "ZERONAMESPACELIMIT"               => ZeroNamespaceLimit
-    case "TOOMANYCONCURRENTREQUESTS"        => TooManyConcurrentRequests
-    case "UNKNOWNERROR"                     => UnknownError
+    case "RESOURCENOTENOUGHERROR" => ResourceNotEnoughError
+    case "NONEXECUTABLEACTIONERROR" => NonExecutableActionError // must equal NonExecutableActionError.toString.toUpperCase
+    case "DBFETCHERROR" => DBFetchError
+    case "WHISKERROR" => WhiskError
+    case "BLACKBOXERROR" => BlackBoxError
+    case "TIMEOUTERROR" => TimeoutError
+    case "ZERONAMESPACELIMIT" => ZeroNamespaceLimit
+    case "TOOMANYCONCURRENTREQUESTS" => TooManyConcurrentRequests
+    case "UNKNOWNERROR" => UnknownError
   }
 
   implicit val serds = new RootJsonFormat[ContainerCreationError] {
     override def write(error: ContainerCreationError): JsValue = JsString(error.toString)
+
     override def read(json: JsValue): ContainerCreationError =
       Try {
         val JsString(str) = json
-        ContainerCreationError.fromName(str.trim.toUpperCase)
+        ContainerCreationError.parse(str.trim.toUpperCase)
       } getOrElse {
         throw deserializationError("ContainerCreationError must be a valid string")
       }
@@ -632,7 +678,7 @@ case class ContainerCreationAckMessage(override val transid: TransactionId,
                                        retryCount: Int = 0,
                                        error: Option[ContainerCreationError] = None,
                                        reason: Option[String] = None)
-    extends Message {
+  extends Message {
 
   /**
    * Serializes message to string. Must be idempotent.
@@ -642,6 +688,7 @@ case class ContainerCreationAckMessage(override val transid: TransactionId,
 
 object ContainerCreationAckMessage extends DefaultJsonProtocol {
   def parse(msg: String): Try[ContainerCreationAckMessage] = Try(serdes.read(msg.parseJson))
+
   private implicit val fqnSerdes = FullyQualifiedEntityName.serdes
   private implicit val byteSizeSerdes = size.serdes
   implicit val serdes = jsonFormat12(ContainerCreationAckMessage.apply)
diff --git a/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala b/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala
index ded9a6a..258fc89 100644
--- a/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala
+++ b/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala
@@ -29,36 +29,55 @@ object SizeUnits extends Enumeration {
 
   sealed abstract class Unit() {
     def toBytes(n: Long): Long
+
     def toKBytes(n: Long): Long
+
     def toMBytes(n: Long): Long
+
     def toGBytes(n: Long): Long
   }
 
   case object BYTE extends Unit {
     def toBytes(n: Long): Long = n
+
     def toKBytes(n: Long): Long = n / 1024
+
     def toMBytes(n: Long): Long = n / 1024 / 1024
+
     def toGBytes(n: Long): Long = n / 1024 / 1024 / 1024
   }
+
   case object KB extends Unit {
     def toBytes(n: Long): Long = n * 1024
+
     def toKBytes(n: Long): Long = n
+
     def toMBytes(n: Long): Long = n / 1024
+
     def toGBytes(n: Long): Long = n / 1024 / 1024
 
   }
+
   case object MB extends Unit {
     def toBytes(n: Long): Long = n * 1024 * 1024
+
     def toKBytes(n: Long): Long = n * 1024
+
     def toMBytes(n: Long): Long = n
+
     def toGBytes(n: Long): Long = n / 1024
   }
+
   case object GB extends Unit {
     def toBytes(n: Long): Long = n * 1024 * 1024 * 1024
+
     def toKBytes(n: Long): Long = n * 1024 * 1024
+
     def toMBytes(n: Long): Long = n * 1024
+
     def toGBytes(n: Long): Long = n
   }
+
 }
 
 case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize] {
@@ -66,7 +85,9 @@ case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize]
   require(size >= 0, "a negative size of an object is not allowed.")
 
   def toBytes = unit.toBytes(size)
+
   def toKB = unit.toKBytes(size)
+
   def toMB = unit.toMBytes(size)
 
   def +(other: ByteSize): ByteSize = {
@@ -102,15 +123,15 @@ case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize]
 
   override def equals(that: Any): Boolean = that match {
     case t: ByteSize => compareTo(t) == 0
-    case _           => false
+    case _ => false
   }
 
   override def toString = {
     unit match {
       case SizeUnits.BYTE => s"$size B"
-      case SizeUnits.KB   => s"$size KB"
-      case SizeUnits.MB   => s"$size MB"
-      case SizeUnits.GB   => s"$size GB"
+      case SizeUnits.KB => s"$size KB"
+      case SizeUnits.MB => s"$size MB"
+      case SizeUnits.GB => s"$size GB"
     }
   }
 }
@@ -138,6 +159,7 @@ object ByteSize {
 }
 
 object size {
+
   implicit class SizeInt(n: Int) extends SizeConversion {
     def sizeIn(unit: SizeUnits.Unit): ByteSize = ByteSize(n, unit)
   }
@@ -163,24 +185,31 @@ object size {
   implicit val pureconfigReader =
     ConfigReader[ConfigValue].map(v => ByteSize(v.atKey("key").getBytes("key"), SizeUnits.BYTE))
 
-  implicit val serdes = new RootJsonFormat[ByteSize] {
+  protected[core] implicit val serdes = new RootJsonFormat[ByteSize] {
     def write(b: ByteSize) = JsString(b.toString)
 
     def read(value: JsValue): ByteSize = value match {
       case JsString(s) => ByteSize.fromString(s)
-      case _           => deserializationError(formatError)
+      case _ => deserializationError(formatError)
     }
   }
 }
 
 trait SizeConversion {
   def B = sizeIn(SizeUnits.BYTE)
+
   def KB = sizeIn(SizeUnits.KB)
+
   def MB = sizeIn(SizeUnits.MB)
+
   def GB: ByteSize = sizeIn(SizeUnits.GB)
+
   def bytes = B
+
   def kilobytes = KB
+
   def megabytes = MB
+
   def gigabytes: ByteSize = GB
 
   def sizeInBytes = sizeIn(SizeUnits.BYTE)
diff --git a/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala b/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala
new file mode 100644
index 0000000..05cddd8
--- /dev/null
+++ b/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.openwhisk.core.invoker
+
+import java.nio.charset.StandardCharsets
+
+import akka.actor.{ActorRef, ActorSystem, Props}
+import org.apache.kafka.clients.producer.RecordMetadata
+import org.apache.openwhisk.common.{GracefulShutdown, Logging, TransactionId}
+import org.apache.openwhisk.core.WarmUp.isWarmUpAction
+import org.apache.openwhisk.core.WhiskConfig
+import org.apache.openwhisk.core.connector.ContainerCreationError.DBFetchError
+import org.apache.openwhisk.core.connector._
+import org.apache.openwhisk.core.containerpool.v2.{CreationContainer, DeletionContainer}
+import org.apache.openwhisk.core.database.{
+  ArtifactStore,
+  DocumentTypeMismatchException,
+  DocumentUnreadable,
+  NoDocumentException
+}
+import org.apache.openwhisk.core.entity._
+import org.apache.openwhisk.http.Messages
+
+import scala.concurrent.duration._
+import scala.concurrent.{ExecutionContext, Future}
+import scala.util.{Failure, Success}
+
+class ContainerMessageConsumer(
+  invokerInstanceId: InvokerInstanceId,
+  containerPool: ActorRef,
+  entityStore: ArtifactStore[WhiskEntity],
+  config: WhiskConfig,
+  msgProvider: MessagingProvider,
+  longPollDuration: FiniteDuration,
+  maxPeek: Int,
+  sendAckToScheduler: (SchedulerInstanceId, ContainerCreationAckMessage) => Future[RecordMetadata])(
+  implicit actorSystem: ActorSystem,
+  executionContext: ExecutionContext,
+  logging: Logging) {
+  // Consumes ContainerCreationMessage / ContainerDeletionMessage records from this invoker's own topic.
+  private val topic = s"${Invoker.topicPrefix}invoker${invokerInstanceId.toInt}"
+  private val consumer =
+    msgProvider.getConsumer(config, topic, topic, maxPeek, maxPollInterval = TimeLimit.MAX_DURATION + 1.minute)
+
+  private def handler(bytes: Array[Byte]): Future[Unit] = Future {
+    val raw = new String(bytes, StandardCharsets.UTF_8)
+    ContainerMessage.parse(raw) match {
+      case Success(creation: ContainerCreationMessage) if isWarmUpAction(creation.action) =>
+        logging.info(
+          this,
+          s"container creation message for ${creation.invocationNamespace}/${creation.action} is received (creationId: ${creation.creationId})")
+        feed ! MessageFeed.Processed
+      // Regular creation: fetch the action at the requested revision, then hand it to the container pool.
+      case Success(creation: ContainerCreationMessage) =>
+        implicit val transid: TransactionId = creation.transid
+        logging
+          .info(this, s"container creation message for ${creation.invocationNamespace}/${creation.action} is received")
+        WhiskAction
+          .get(entityStore, creation.action.toDocId, creation.revision, fromCache = true)
+          .map { action =>
+            containerPool ! CreationContainer(creation, action)
+            feed ! MessageFeed.Processed
+          }
+          .recover {
+            case t =>
+              val message = t match {
+                case _: NoDocumentException =>
+                  Messages.actionRemovedWhileInvoking
+                case _: DocumentTypeMismatchException | _: DocumentUnreadable =>
+                  Messages.actionMismatchWhileInvoking
+                case e: Throwable =>
+                  logging.error(this, s"An unknown DB connection error occurred while fetching an action: $e.")
+                  Messages.actionFetchErrorWhileInvoking
+              }
+              logging.error(
+                this,
+                s"failed to fetch action ${creation.invocationNamespace}/${creation.action}, error: $message (creationId: ${creation.creationId})")
+              // Report the fetch failure to the root scheduler as a DBFetchError creation ack.
+              val ack = ContainerCreationAckMessage(
+                creation.transid,
+                creation.creationId,
+                creation.invocationNamespace,
+                creation.action,
+                creation.revision,
+                creation.whiskActionMetaData,
+                invokerInstanceId,
+                creation.schedulerHost,
+                creation.rpcPort,
+                creation.retryCount,
+                Some(DBFetchError),
+                Some(message))
+              sendAckToScheduler(creation.rootSchedulerIndex, ack)
+              feed ! MessageFeed.Processed
+          }
+      case Success(deletion: ContainerDeletionMessage) =>
+        implicit val transid: TransactionId = deletion.transid
+        logging.info(this, s"deletion message for ${deletion.invocationNamespace}/${deletion.action} is received")
+        containerPool ! DeletionContainer(deletion)
+        feed ! MessageFeed.Processed
+      case Failure(t) =>
+        logging.error(this, s"Failed to parse $raw, error: ${t.getMessage}")
+        feed ! MessageFeed.Processed
+      // Parsed successfully, but is neither a creation nor a deletion message.
+      case _ =>
+        logging.error(this, s"Unexpected message received $raw")
+        feed ! MessageFeed.Processed
+    }
+  }
+  // Every branch of `handler` answers the feed with MessageFeed.Processed once it is done with a record.
+  private val feed = actorSystem.actorOf(Props {
+    new MessageFeed("containerCreation", logging, consumer, maxPeek, longPollDuration, handler)
+  })
+  /** Asks the message feed to shut down gracefully. */
+  def close(): Unit = {
+    feed ! GracefulShutdown
+  }
+}
diff --git a/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala b/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala
new file mode 100644
index 0000000..5ceddfb
--- /dev/null
+++ b/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.openwhisk.core.invoker.test
+
+import java.nio.charset.StandardCharsets
+
+import akka.actor.ActorSystem
+import akka.stream.ActorMaterializer
+import akka.testkit.{TestKit, TestProbe}
+import common.StreamLogging
+import org.apache.kafka.clients.producer.RecordMetadata
+import org.apache.openwhisk.common.{Logging, TransactionId}
+import org.apache.openwhisk.core.{WarmUp, WhiskConfig}
+import org.apache.openwhisk.core.connector.ContainerCreationError._
+import org.apache.openwhisk.core.connector._
+import org.apache.openwhisk.core.connector.test.TestConnector
+import org.apache.openwhisk.core.containerpool.v2.CreationContainer
+import org.apache.openwhisk.core.database.test.DbUtils
+import org.apache.openwhisk.core.entity.ExecManifest.{ImageName, RuntimeManifest}
+import org.apache.openwhisk.core.entity._
+import org.apache.openwhisk.core.entity.size._
+import org.apache.openwhisk.core.entity.test.ExecHelpers
+import org.apache.openwhisk.core.invoker.ContainerMessageConsumer
+import org.apache.openwhisk.http.Messages
+import org.apache.openwhisk.utils.{retry => utilRetry}
+import org.junit.runner.RunWith
+import org.scalamock.scalatest.MockFactory
+import org.scalatest.junit.JUnitRunner
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FlatSpecLike, Matchers}
+
+import scala.concurrent.Future
+import scala.concurrent.duration._
+import scala.util.Try
+
+@RunWith(classOf[JUnitRunner])
+class ContainerMessageConsumerTests
+    extends TestKit(ActorSystem("ContainerMessageConsumer"))
+    with FlatSpecLike
+    with Matchers
+    with BeforeAndAfterEach
+    with BeforeAndAfterAll
+    with StreamLogging
+    with MockFactory
+    with DbUtils
+    with ExecHelpers {
+
+  // Alias of TestKit's `system`, so the same ActorSystem also satisfies implicit ActorSystem parameters.
+  implicit val actualActorSystem: ActorSystem = system
+  implicit val ec: scala.concurrent.ExecutionContextExecutor = actualActorSystem.dispatcher
+  implicit val materializer: ActorMaterializer = ActorMaterializer()
+  implicit val transId: TransactionId = TransactionId.testing
+  implicit val creationId: CreationId = CreationId.generate()
+
+  override def afterAll(): Unit = {
+    TestKit.shutdownActorSystem(system)
+    super.afterAll()
+  }
+
+  // DbUtils.cleanup() presumably deletes what the test stored via `put`; always chain to super so stacked hooks run.
+  override def afterEach(): Unit = {
+    try cleanup()
+    finally super.afterEach()
+  }
+
+  private val whiskConfig = new WhiskConfig(
+    Map(
+      WhiskConfig.actionInvokePerMinuteLimit -> null,
+      WhiskConfig.triggerFirePerMinuteLimit -> null,
+      WhiskConfig.actionInvokeConcurrentLimit -> null,
+      WhiskConfig.runtimesManifest -> null,
+      WhiskConfig.actionSequenceMaxLimit -> null))
+
+  private val entityStore = WhiskEntityStore.datastore()
+  private val producer = stub[MessageProducer]
+
+  private val defaultUserMemory: ByteSize = 1024.MB
+  private val invokerInstance = InvokerInstanceId(0, userMemory = defaultUserMemory)
+  private val schedulerInstanceId = SchedulerInstanceId("0")
+
+  private val invocationNamespace = EntityName("invocationSpace")
+
+  private val schedulerHost = "127.17.0.1"
+
+  private val rpcPort = 13001
+
+  // Runtime and code shared by every action these tests create.
+  private val testExec = CodeExecAsString(RuntimeManifest("nodejs:10", ImageName("testImage")), "testCode", None)
+
+  /** A MessagingProvider that hands out `consumer` and its producer; ensureTopic always succeeds. */
+  private def fakeMessageProvider(consumer: TestConnector): MessagingProvider = {
+    new MessagingProvider {
+      override def getConsumer(
+        whiskConfig: WhiskConfig,
+        groupId: String,
+        topic: String,
+        maxPeek: Int,
+        maxPollInterval: FiniteDuration)(implicit logging: Logging, actorSystem: ActorSystem): MessageConsumer =
+        consumer
+
+      override def getProducer(config: WhiskConfig, maxRequestSize: Option[ByteSize])(
+        implicit logging: Logging,
+        actorSystem: ActorSystem): MessageProducer = consumer.getProducer()
+
+      override def ensureTopic(config: WhiskConfig,
+                               topic: String,
+                               topicConfig: String,
+                               maxMessageBytes: Option[ByteSize])(implicit logging: Logging): Try[Unit] = Try {}
+    }
+  }
+
+  /** Publishes `ackMessage` through `producer` on the "creationAck" + scheduler-id topic. */
+  def sendAckToScheduler(producer: MessageProducer)(schedulerInstanceId: SchedulerInstanceId,
+                                                    ackMessage: ContainerCreationAckMessage): Future[RecordMetadata] = {
+    val topic = s"creationAck${schedulerInstanceId.asString}"
+    producer.send(topic, ackMessage)
+  }
+
+  /** The ack this invoker is expected to send for `creationMessage`, carrying `error` and `reason`. */
+  private def createAckMsg(creationMessage: ContainerCreationMessage,
+                           error: Option[ContainerCreationError],
+                           reason: Option[String]) = {
+    ContainerCreationAckMessage(
+      creationMessage.transid,
+      creationMessage.creationId,
+      creationMessage.invocationNamespace,
+      creationMessage.action,
+      creationMessage.revision,
+      creationMessage.whiskActionMetaData,
+      invokerInstance,
+      creationMessage.schedulerHost,
+      creationMessage.rpcPort,
+      creationMessage.retryCount,
+      error,
+      reason)
+  }
+
+  /** An action `name` in `namespace` backed by `testExec`, with a one-minute time limit. */
+  private def testAction(namespace: EntityPath, name: EntityName): WhiskAction =
+    WhiskAction(namespace, name, testExec, limits = ActionLimits(TimeLimit(1.minute)))
+
+  /** A creation request for `action`, metadata derived from `testExec`; `withVersion` selects the FQN form. */
+  private def creationMessageFor(action: WhiskAction, withVersion: Boolean): ContainerCreationMessage = {
+    val actionMetadata =
+      WhiskActionMetaData(
+        action.namespace,
+        action.name,
+        CodeExecMetaDataAsString(testExec.manifest, entryPoint = testExec.entryPoint),
+        action.parameters,
+        action.limits,
+        action.version,
+        action.publish,
+        action.annotations)
+
+    ContainerCreationMessage(
+      transId,
+      invocationNamespace.asString,
+      action.fullyQualifiedName(withVersion),
+      DocRevision.empty,
+      actionMetadata,
+      schedulerInstanceId,
+      schedulerHost,
+      rpcPort,
+      creationId = creationId)
+  }
+
+  /**
+   * Starts a ContainerMessageConsumer that reads via `msgProvider`, forwards to `pool` and acks through
+   * `ackProducer`, runs `body`, and closes the consumer afterwards even if `body` fails.
+   */
+  private def withConsumer[T](pool: TestProbe, msgProvider: MessagingProvider, ackProducer: MessageProducer)(
+    body: => T): T = {
+    val consumer =
+      new ContainerMessageConsumer(
+        invokerInstance,
+        pool.ref,
+        entityStore,
+        whiskConfig,
+        msgProvider,
+        200.milliseconds,
+        500,
+        sendAckToScheduler(ackProducer))
+    try body
+    finally consumer.close()
+  }
+
+  /**
+   * Peeks `ackConsumer` (up to 10 tries, 500 ms apart) until exactly one record on `ackTopic` carries `expected`.
+   * The retry budget itself bounds the wait, so no surrounding `within` (which could undercut it) is needed.
+   */
+  private def expectSingleAck(ackConsumer: TestConnector,
+                              ackTopic: String,
+                              expected: ContainerCreationAckMessage): Unit = {
+    utilRetry(
+      {
+        val buffer = ackConsumer.peek(50.millisecond)
+        buffer.size shouldBe 1
+        buffer.head._1 shouldBe ackTopic
+        new String(buffer.head._4, StandardCharsets.UTF_8) shouldBe expected.serialize
+      },
+      10,
+      Some(500.millisecond))
+  }
+
+  behavior of "ContainerMessageConsumer"
+
+  it should "forward ContainerCreationMessage to containerPool" in {
+    val pool = TestProbe()
+    val mockConsumer = new TestConnector("fakeTopic", 4, true)
+
+    withConsumer(pool, fakeMessageProvider(mockConsumer), producer) {
+      val action = testAction(EntityPath("testns"), EntityName("testAction"))
+      put(entityStore, action)
+
+      mockConsumer.send(creationMessageFor(action, withVersion = true))
+
+      pool.expectMsgPF() {
+        case CreationContainer(_, _) => true
+      }
+    }
+  }
+
+  it should "send ack(failed) to scheduler when failed to get action from DB" in {
+    val pool = TestProbe()
+    val creationConsumer = new TestConnector("creation", 4, true)
+    val ackTopic = "ack"
+    val ackConsumer = new TestConnector(ackTopic, 4, true)
+
+    withConsumer(pool, fakeMessageProvider(creationConsumer), ackConsumer.getProducer()) {
+      val whiskAction = testAction(EntityPath("testns"), EntityName("testAction2"))
+      val creationMessage = creationMessageFor(whiskAction, withVersion = true)
+
+      // action doesn't exist
+      creationConsumer.send(creationMessage)
+      expectSingleAck(
+        ackConsumer,
+        ackTopic,
+        createAckMsg(creationMessage, Some(DBFetchError), Some(Messages.actionRemovedWhileInvoking)))
+      pool.expectNoMessage(2.seconds)
+
+      // action exists but its stored revision differs from the requested one
+      put(entityStore, whiskAction)
+      val staleRevisionMessage = creationMessage.copy(revision = DocRevision("1-fake"))
+      creationConsumer.send(staleRevisionMessage)
+      expectSingleAck(
+        ackConsumer,
+        ackTopic,
+        createAckMsg(staleRevisionMessage, Some(DBFetchError), Some(Messages.actionFetchErrorWhileInvoking)))
+      pool.expectNoMessage(2.seconds)
+    }
+  }
+
+  it should "drop messages of warm-up action" in {
+    val pool = TestProbe()
+    val mockConsumer = new TestConnector("fakeTopic", 4, true)
+
+    withConsumer(pool, fakeMessageProvider(mockConsumer), producer) {
+      val action = testAction(WarmUp.warmUpAction.namespace.toPath, WarmUp.warmUpAction.name)
+      put(entityStore, action)
+
+      mockConsumer.send(creationMessageFor(action, withVersion = false))
+
+      pool.expectNoMessage(1.seconds)
+    }
+  }
+}