You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@openwhisk.apache.org by ni...@apache.org on 2021/05/28 01:14:30 UTC
[openwhisk] branch master updated: [New Scheduler] Add container
message consumer (#5111)
This is an automated email from the ASF dual-hosted git repository.
ningyougang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/openwhisk.git
The following commit(s) were added to refs/heads/master by this push:
new b818c3b [New Scheduler] Add container message consumer (#5111)
b818c3b is described below
commit b818c3b3e8bd3fa9ac7742d1b8c051ec09b76ae2
Author: Seonghyun Oh <oh...@navercorp.com>
AuthorDate: Fri May 28 10:14:11 2021 +0900
[New Scheduler] Add container message consumer (#5111)
* Add container message consumer
* Reformat code
* Fix test case error
Co-authored-by: ning.yougang <ni...@navercorp.com>
---
.../apache/openwhisk/core/connector/Message.scala | 135 ++++++---
.../org/apache/openwhisk/core/entity/Size.scala | 41 ++-
.../core/invoker/ContainerMessageConsumer.scala | 132 +++++++++
.../test/ContainerMessageConsumerTests.scala | 328 +++++++++++++++++++++
4 files changed, 586 insertions(+), 50 deletions(-)
diff --git a/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala b/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala
index 88f4f11..ba05c17 100644
--- a/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala
+++ b/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala
@@ -60,7 +60,7 @@ case class ActivationMessage(override val transid: TransactionId,
lockedArgs: Map[String, String] = Map.empty,
cause: Option[ActivationId] = None,
traceContext: Option[Map[String, String]] = None)
- extends Message {
+ extends Message {
override def serialize = ActivationMessage.serdes.write(this).compactPrint
@@ -78,6 +78,7 @@ case class ActivationMessage(override val transid: TransactionId,
*/
abstract class AcknowledegmentMessage(private val tid: TransactionId) extends Message {
override val transid: TransactionId = tid
+
override def serialize: String = AcknowledegmentMessage.serdes.write(this).compactPrint
/** Pithy descriptor for logging. */
@@ -115,17 +116,23 @@ abstract class AcknowledegmentMessage(private val tid: TransactionId) extends Me
* The constructor is private so that callers must use the more restrictive constructors which ensure the respose is always
* Right when this message is created.
*/
-case class CombinedCompletionAndResultMessage private (override val transid: TransactionId,
- response: Either[ActivationId, WhiskActivation],
- override val isSystemError: Option[Boolean],
- instance: InstanceId)
- extends AcknowledegmentMessage(transid) {
+case class CombinedCompletionAndResultMessage private(override val transid: TransactionId,
+ response: Either[ActivationId, WhiskActivation],
+ override val isSystemError: Option[Boolean],
+ instance: InstanceId)
+ extends AcknowledegmentMessage(transid) {
override def messageType = "combined"
+
override def result = Some(response)
+
override def isSlotFree = Some(instance)
+
override def activationId = response.fold(identity, _.activationId)
+
override def toJson = CombinedCompletionAndResultMessage.serdes.write(this)
+
override def shrink = copy(response = response.flatMap(a => Left(a.activationId)))
+
override def toString = activationId.asString
}
@@ -135,16 +142,21 @@ case class CombinedCompletionAndResultMessage private (override val transid: Tra
* phase notification to the load balancer where an invoker first sends a `ResultMessage` and later sends the
* `CompletionMessage`.
*/
-case class CompletionMessage private (override val transid: TransactionId,
- override val activationId: ActivationId,
- override val isSystemError: Option[Boolean],
- instance: InstanceId)
- extends AcknowledegmentMessage(transid) {
+case class CompletionMessage private(override val transid: TransactionId,
+ override val activationId: ActivationId,
+ override val isSystemError: Option[Boolean],
+ instance: InstanceId)
+ extends AcknowledegmentMessage(transid) {
override def messageType = "completion"
+
override def result = None
+
override def isSlotFree = Some(instance)
+
override def toJson = CompletionMessage.serdes.write(this)
+
override def shrink = this
+
override def toString = activationId.asString
}
@@ -156,15 +168,22 @@ case class CompletionMessage private (override val transid: TransactionId,
* The constructor is private so that callers must use the more restrictive constructors which ensure the respose is always
* Right when this message is created.
*/
-case class ResultMessage private (override val transid: TransactionId, response: Either[ActivationId, WhiskActivation])
- extends AcknowledegmentMessage(transid) {
+case class ResultMessage private(override val transid: TransactionId, response: Either[ActivationId, WhiskActivation])
+ extends AcknowledegmentMessage(transid) {
override def messageType = "result"
+
override def result = Some(response)
+
override def isSlotFree = None
+
override def isSystemError = response.fold(_ => None, a => Some(a.response.isWhiskError))
+
override def activationId = response.fold(identity, _.activationId)
+
override def toJson = ResultMessage.serdes.write(this)
+
override def shrink = copy(response = response.flatMap(a => Left(a.activationId)))
+
override def toString = activationId.asString
}
@@ -234,7 +253,7 @@ object AcknowledegmentMessage extends DefaultJsonProtocol {
Left(value.convertTo[ActivationId])
case _: JsObject => Right(value.convertTo[WhiskActivation])
- case _ => deserializationError("could not read ResultMessage")
+ case _ => deserializationError("could not read ResultMessage")
}
}
@@ -265,6 +284,7 @@ case class PingMessage(instance: InvokerInstanceId) extends Message {
object PingMessage extends DefaultJsonProtocol {
def parse(msg: String) = Try(serdes.read(msg.parseJson))
+
implicit val serdes = jsonFormat(PingMessage.apply _, "name")
}
@@ -276,7 +296,7 @@ object EventMessageBody extends DefaultJsonProtocol {
implicit val format = new JsonFormat[EventMessageBody] {
def write(eventMessageBody: EventMessageBody) = eventMessageBody match {
- case m: Metric => m.toJson
+ case m: Metric => m.toJson
case a: Activation => a.toJson
}
@@ -301,9 +321,11 @@ case class Activation(name: String,
causedBy: Option[String],
size: Option[Int] = None,
userDefinedStatusCode: Option[Int] = None)
- extends EventMessageBody {
+ extends EventMessageBody {
val typeName = Activation.typeName
+
override def serialize = toJson.compactPrint
+
def entityPath: FullyQualifiedEntityName = EntityPath(name).toFullyQualifiedEntityName
def toJson = Activation.activationFormat.write(this)
@@ -327,12 +349,12 @@ object Activation extends DefaultJsonProtocol {
private implicit val durationFormat = new RootJsonFormat[Duration] {
override def write(obj: Duration): JsValue = obj match {
case o if o.isFinite => JsNumber(o.toMillis)
- case _ => JsNumber.zero
+ case _ => JsNumber.zero
}
override def read(json: JsValue): Duration = json match {
case JsNumber(n) if n <= 0 => Duration.Zero
- case JsNumber(n) => toDuration(n.longValue)
+ case JsNumber(n) => toDuration(n.longValue)
}
}
@@ -352,7 +374,7 @@ object Activation extends DefaultJsonProtocol {
"size",
"userDefinedStatusCode")
- /** Get "StatusCode" from result response set by action developer **/
+ /** Get "StatusCode" from result response set by action developer */
def userDefinedStatusCode(result: Option[JsValue]): Option[Int] = {
val statusCode = JsHelpers
.getFieldPath(result.get.asJsObject, ERROR_FIELD, "statusCode")
@@ -394,13 +416,17 @@ object Activation extends DefaultJsonProtocol {
case class Metric(metricName: String, metricValue: Long) extends EventMessageBody {
val typeName = "Metric"
+
override def serialize = toJson.compactPrint
+
def toJson = Metric.metricFormat.write(this).asJsObject
}
object Metric extends DefaultJsonProtocol {
val typeName = "Metric"
+
def parse(msg: String) = Try(metricFormat.read(msg.parseJson))
+
implicit val metricFormat = jsonFormat(Metric.apply _, "metricName", "metricValue")
}
@@ -411,7 +437,7 @@ case class EventMessage(source: String,
userId: UUID,
eventType: String,
timestamp: Long = System.currentTimeMillis())
- extends Message {
+ extends Message {
override def serialize = EventMessage.format.write(this).compactPrint
}
@@ -434,7 +460,7 @@ case class InvokerResourceMessage(status: String,
inProgressMemory: Long,
tags: Seq[String],
dedicatedNamespaces: Seq[String])
- extends Message {
+ extends Message {
/**
* Serializes message to string. Must be idempotent.
@@ -444,6 +470,7 @@ case class InvokerResourceMessage(status: String,
object InvokerResourceMessage extends DefaultJsonProtocol {
def parse(msg: String): Try[InvokerResourceMessage] = Try(serdes.read(msg.parseJson))
+
implicit val serdes =
jsonFormat(
InvokerResourceMessage.apply _,
@@ -462,23 +489,25 @@ object InvokerResourceMessage extends DefaultJsonProtocol {
*
* [
* ...
- * {
- * "data": "RunningData",
- * "fqn": "whisk.system/elasticsearch/status-alarm@0.0.2",
- * "invocationNamespace": "style95",
- * "status": "Running",
- * "waitingActivation": 1
- * },
+ * {
+ * "data": "RunningData",
+ * "fqn": "whisk.system/elasticsearch/status-alarm@0.0.2",
+ * "invocationNamespace": "style95",
+ * "status": "Running",
+ * "waitingActivation": 1
+ * },
* ...
* ]
*/
object StatusQuery
+
case class StatusData(invocationNamespace: String, fqn: String, waitingActivation: Int, status: String, data: String)
- extends Message {
+ extends Message {
override def serialize: String = StatusData.serdes.write(this).compactPrint
}
+
object StatusData extends DefaultJsonProtocol {
implicit val serdes =
@@ -495,9 +524,10 @@ case class ContainerCreationMessage(override val transid: TransactionId,
rpcPort: Int,
retryCount: Int = 0,
creationId: CreationId = CreationId.generate())
- extends ContainerMessage(transid) {
+ extends ContainerMessage(transid) {
override def toJson: JsValue = ContainerCreationMessage.serdes.write(this)
+
override def serialize: String = toJson.compactPrint
}
@@ -526,8 +556,9 @@ case class ContainerDeletionMessage(override val transid: TransactionId,
action: FullyQualifiedEntityName,
revision: DocRevision,
whiskActionMetaData: WhiskActionMetaData)
- extends ContainerMessage(transid) {
+ extends ContainerMessage(transid) {
override def toJson: JsValue = ContainerDeletionMessage.serdes.write(this)
+
override def serialize: String = toJson.compactPrint
}
@@ -544,6 +575,7 @@ object ContainerDeletionMessage extends DefaultJsonProtocol {
abstract class ContainerMessage(private val tid: TransactionId) extends Message {
override val transid: TransactionId = tid
+
override def serialize: String = ContainerMessage.serdes.write(this).compactPrint
/** Serializes the message to JSON. */
@@ -569,18 +601,31 @@ object ContainerMessage extends DefaultJsonProtocol {
}
sealed trait ContainerCreationError
+
object ContainerCreationError extends Enumeration {
+
case object NoAvailableInvokersError extends ContainerCreationError
+
case object NoAvailableResourceInvokersError extends ContainerCreationError
+
case object ResourceNotEnoughError extends ContainerCreationError
+
case object WhiskError extends ContainerCreationError
+
case object UnknownError extends ContainerCreationError
+
case object TimeoutError extends ContainerCreationError
+
case object ShuttingDownError extends ContainerCreationError
+
case object NonExecutableActionError extends ContainerCreationError
+
case object DBFetchError extends ContainerCreationError
+
case object BlackBoxError extends ContainerCreationError
+
case object ZeroNamespaceLimit extends ContainerCreationError
+
case object TooManyConcurrentRequests extends ContainerCreationError
val whiskErrors: Set[ContainerCreationError] =
@@ -594,26 +639,27 @@ object ContainerCreationError extends Enumeration {
TimeoutError,
ZeroNamespaceLimit)
- def fromName(name: String) = name.toUpperCase match {
- case "NOAVAILABLEINVOKERSERROR" => NoAvailableInvokersError
+ // Maps an upper-cased error name back to its case object; a name without a case here fails with a MatchError.
+ private def parse(name: String) = name.toUpperCase match {
+ case "NOAVAILABLEINVOKERSERROR" => NoAvailableInvokersError
case "NOAVAILABLERESOURCEINVOKERSERROR" => NoAvailableResourceInvokersError
- case "RESOURCENOTENOUGHERROR" => ResourceNotEnoughError
- case "NONEXECUTBLEACTIONERROR" => NonExecutableActionError
- case "DBFETCHERROR" => DBFetchError
- case "WHISKERROR" => WhiskError
- case "BLACKBOXERROR" => BlackBoxError
- case "TIMEOUTERROR" => TimeoutError
- case "ZERONAMESPACELIMIT" => ZeroNamespaceLimit
- case "TOOMANYCONCURRENTREQUESTS" => TooManyConcurrentRequests
- case "UNKNOWNERROR" => UnknownError
+ case "RESOURCENOTENOUGHERROR" => ResourceNotEnoughError
+ // the serializer writes `error.toString`, i.e. the correctly spelled name; the old misspelled key is kept as well
+ case "NONEXECUTABLEACTIONERROR" | "NONEXECUTBLEACTIONERROR" => NonExecutableActionError
+ case "DBFETCHERROR" => DBFetchError
+ case "WHISKERROR" => WhiskError
+ case "BLACKBOXERROR" => BlackBoxError
+ case "TIMEOUTERROR" => TimeoutError
+ case "SHUTTINGDOWNERROR" => ShuttingDownError
+ case "ZERONAMESPACELIMIT" => ZeroNamespaceLimit
+ case "TOOMANYCONCURRENTREQUESTS" => TooManyConcurrentRequests
+ case "UNKNOWNERROR" => UnknownError
}
implicit val serds = new RootJsonFormat[ContainerCreationError] {
override def write(error: ContainerCreationError): JsValue = JsString(error.toString)
+
override def read(json: JsValue): ContainerCreationError =
Try {
val JsString(str) = json
- ContainerCreationError.fromName(str.trim.toUpperCase)
+ ContainerCreationError.parse(str.trim.toUpperCase)
} getOrElse {
throw deserializationError("ContainerCreationError must be a valid string")
}
@@ -632,7 +678,7 @@ case class ContainerCreationAckMessage(override val transid: TransactionId,
retryCount: Int = 0,
error: Option[ContainerCreationError] = None,
reason: Option[String] = None)
- extends Message {
+ extends Message {
/**
* Serializes message to string. Must be idempotent.
@@ -642,6 +688,7 @@ case class ContainerCreationAckMessage(override val transid: TransactionId,
object ContainerCreationAckMessage extends DefaultJsonProtocol {
def parse(msg: String): Try[ContainerCreationAckMessage] = Try(serdes.read(msg.parseJson))
+
private implicit val fqnSerdes = FullyQualifiedEntityName.serdes
private implicit val byteSizeSerdes = size.serdes
implicit val serdes = jsonFormat12(ContainerCreationAckMessage.apply)
diff --git a/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala b/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala
index ded9a6a..258fc89 100644
--- a/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala
+++ b/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala
@@ -29,36 +29,55 @@ object SizeUnits extends Enumeration {
sealed abstract class Unit() {
def toBytes(n: Long): Long
+
def toKBytes(n: Long): Long
+
def toMBytes(n: Long): Long
+
def toGBytes(n: Long): Long
}
case object BYTE extends Unit {
def toBytes(n: Long): Long = n
+
def toKBytes(n: Long): Long = n / 1024
+
def toMBytes(n: Long): Long = n / 1024 / 1024
+
def toGBytes(n: Long): Long = n / 1024 / 1024 / 1024
}
+
case object KB extends Unit {
def toBytes(n: Long): Long = n * 1024
+
def toKBytes(n: Long): Long = n
+
def toMBytes(n: Long): Long = n / 1024
+
def toGBytes(n: Long): Long = n / 1024 / 1024
}
+
case object MB extends Unit {
def toBytes(n: Long): Long = n * 1024 * 1024
+
def toKBytes(n: Long): Long = n * 1024
+
def toMBytes(n: Long): Long = n
+
def toGBytes(n: Long): Long = n / 1024
}
+
case object GB extends Unit {
def toBytes(n: Long): Long = n * 1024 * 1024 * 1024
+
def toKBytes(n: Long): Long = n * 1024 * 1024
+
def toMBytes(n: Long): Long = n * 1024
+
def toGBytes(n: Long): Long = n
}
+
}
case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize] {
@@ -66,7 +85,9 @@ case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize]
require(size >= 0, "a negative size of an object is not allowed.")
def toBytes = unit.toBytes(size)
+
def toKB = unit.toKBytes(size)
+
def toMB = unit.toMBytes(size)
def +(other: ByteSize): ByteSize = {
@@ -102,15 +123,15 @@ case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize]
override def equals(that: Any): Boolean = that match {
case t: ByteSize => compareTo(t) == 0
- case _ => false
+ case _ => false
}
override def toString = {
unit match {
case SizeUnits.BYTE => s"$size B"
- case SizeUnits.KB => s"$size KB"
- case SizeUnits.MB => s"$size MB"
- case SizeUnits.GB => s"$size GB"
+ case SizeUnits.KB => s"$size KB"
+ case SizeUnits.MB => s"$size MB"
+ case SizeUnits.GB => s"$size GB"
}
}
}
@@ -138,6 +159,7 @@ object ByteSize {
}
object size {
+
implicit class SizeInt(n: Int) extends SizeConversion {
def sizeIn(unit: SizeUnits.Unit): ByteSize = ByteSize(n, unit)
}
@@ -163,24 +185,31 @@ object size {
implicit val pureconfigReader =
ConfigReader[ConfigValue].map(v => ByteSize(v.atKey("key").getBytes("key"), SizeUnits.BYTE))
- implicit val serdes = new RootJsonFormat[ByteSize] {
+ protected[core] implicit val serdes = new RootJsonFormat[ByteSize] {
def write(b: ByteSize) = JsString(b.toString)
def read(value: JsValue): ByteSize = value match {
case JsString(s) => ByteSize.fromString(s)
- case _ => deserializationError(formatError)
+ case _ => deserializationError(formatError)
}
}
}
trait SizeConversion {
def B = sizeIn(SizeUnits.BYTE)
+
def KB = sizeIn(SizeUnits.KB)
+
def MB = sizeIn(SizeUnits.MB)
+
def GB: ByteSize = sizeIn(SizeUnits.GB)
+
def bytes = B
+
def kilobytes = KB
+
def megabytes = MB
+
def gigabytes: ByteSize = GB
def sizeInBytes = sizeIn(SizeUnits.BYTE)
diff --git a/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala b/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala
new file mode 100644
index 0000000..05cddd8
--- /dev/null
+++ b/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.openwhisk.core.invoker
+
+import java.nio.charset.StandardCharsets
+
+import akka.actor.{ActorRef, ActorSystem, Props}
+import org.apache.kafka.clients.producer.RecordMetadata
+import org.apache.openwhisk.common.{GracefulShutdown, Logging, TransactionId}
+import org.apache.openwhisk.core.WarmUp.isWarmUpAction
+import org.apache.openwhisk.core.WhiskConfig
+import org.apache.openwhisk.core.connector.ContainerCreationError.DBFetchError
+import org.apache.openwhisk.core.connector._
+import org.apache.openwhisk.core.containerpool.v2.{CreationContainer, DeletionContainer}
+import org.apache.openwhisk.core.database.{
+ ArtifactStore,
+ DocumentTypeMismatchException,
+ DocumentUnreadable,
+ NoDocumentException
+}
+import org.apache.openwhisk.core.entity._
+import org.apache.openwhisk.http.Messages
+
+import scala.concurrent.duration._
+import scala.concurrent.{ExecutionContext, Future}
+import scala.util.{Failure, Success}
+
+/**
+ * Consumes container creation and deletion messages from this invoker's topic and forwards them to the
+ * container pool. Every message, whether it could be handled or not, is acknowledged to the feed with
+ * `MessageFeed.Processed`.
+ *
+ * @param invokerInstanceId identifies this invoker; its numeric id is part of the consumed topic name
+ * @param containerPool actor that receives `CreationContainer` and `DeletionContainer` requests
+ * @param entityStore store from which the action named by a creation message is fetched
+ * @param config configuration handed to the messaging provider
+ * @param msgProvider provider used to create the topic consumer
+ * @param longPollDuration poll duration handed to the message feed
+ * @param maxPeek peek limit handed to both the consumer and the message feed
+ * @param sendAckToScheduler sends a failure ack to a scheduler when the action cannot be fetched
+ */
+class ContainerMessageConsumer(
+  invokerInstanceId: InvokerInstanceId,
+  containerPool: ActorRef,
+  entityStore: ArtifactStore[WhiskEntity],
+  config: WhiskConfig,
+  msgProvider: MessagingProvider,
+  longPollDuration: FiniteDuration,
+  maxPeek: Int,
+  sendAckToScheduler: (SchedulerInstanceId, ContainerCreationAckMessage) => Future[RecordMetadata])(
+  implicit actorSystem: ActorSystem,
+  executionContext: ExecutionContext,
+  logging: Logging) {
+
+  // the topic name doubles as the consumer group id
+  private val topic = s"${Invoker.topicPrefix}invoker${invokerInstanceId.toInt}"
+  private val consumer =
+    msgProvider.getConsumer(config, topic, topic, maxPeek, maxPollInterval = TimeLimit.MAX_DURATION + 1.minute)
+
+  /**
+   * Decodes one record and dispatches it by message type. The action lookup for a creation message is
+   * started here but not awaited, so the returned future does not cover it.
+   */
+  private def handler(bytes: Array[Byte]): Future[Unit] = Future {
+    val raw = new String(bytes, StandardCharsets.UTF_8)
+    ContainerMessage.parse(raw) match {
+      case Success(creation: ContainerCreationMessage) if isWarmUpAction(creation.action) =>
+        // warm-up actions are not forwarded to the pool; the message is only logged and acknowledged
+        logging.info(
+          this,
+          s"container creation message for ${creation.invocationNamespace}/${creation.action} is received (creationId: ${creation.creationId})")
+        feed ! MessageFeed.Processed
+
+      case Success(creation: ContainerCreationMessage) =>
+        implicit val transid: TransactionId = creation.transid
+        logging
+          .info(this, s"container creation message for ${creation.invocationNamespace}/${creation.action} is received")
+        WhiskAction
+          .get(entityStore, creation.action.toDocId, creation.revision, fromCache = true)
+          .map { action =>
+            containerPool ! CreationContainer(creation, action)
+            feed ! MessageFeed.Processed
+          }
+          .recover {
+            case t =>
+              // translate the lookup failure into the reason reported back to the scheduler
+              val message = t match {
+                case _: NoDocumentException => Messages.actionRemovedWhileInvoking
+                case _: DocumentTypeMismatchException | _: DocumentUnreadable =>
+                  Messages.actionMismatchWhileInvoking
+                case e: Throwable =>
+                  logging.error(this, s"An unknown DB connection error occurred while fetching an action: $e.")
+                  Messages.actionFetchErrorWhileInvoking
+              }
+              logging.error(
+                this,
+                s"failed to fetch action ${creation.invocationNamespace}/${creation.action}, error: $message (creationId: ${creation.creationId})")
+
+              val ack = ContainerCreationAckMessage(
+                creation.transid,
+                creation.creationId,
+                creation.invocationNamespace,
+                creation.action,
+                creation.revision,
+                creation.whiskActionMetaData,
+                invokerInstanceId,
+                creation.schedulerHost,
+                creation.rpcPort,
+                creation.retryCount,
+                Some(DBFetchError),
+                Some(message))
+              // the ack is sent without waiting for it, so log a failed send instead of dropping it silently
+              sendAckToScheduler(creation.rootSchedulerIndex, ack).failed.foreach { e =>
+                logging.error(this, s"failed to send ack to scheduler (creationId: ${creation.creationId}), error: $e")
+              }
+              feed ! MessageFeed.Processed
+          }
+
+      case Success(deletion: ContainerDeletionMessage) =>
+        implicit val transid: TransactionId = deletion.transid
+        logging.info(this, s"deletion message for ${deletion.invocationNamespace}/${deletion.action} is received")
+        containerPool ! DeletionContainer(deletion)
+        feed ! MessageFeed.Processed
+
+      case Failure(t) =>
+        // log the decoded payload; interpolating the byte array itself would only print its identity
+        logging.error(this, s"Failed to parse $raw, error: ${t.getMessage}")
+        feed ! MessageFeed.Processed
+
+      case _ =>
+        logging.error(this, s"Unexpected message received $raw")
+        feed ! MessageFeed.Processed
+    }
+  }
+
+  private val feed = actorSystem.actorOf(Props {
+    new MessageFeed("containerCreation", logging, consumer, maxPeek, longPollDuration, handler)
+  })
+
+  /** Asks the message feed to shut down gracefully. */
+  def close(): Unit = {
+    feed ! GracefulShutdown
+  }
+}
diff --git a/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala b/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala
new file mode 100644
index 0000000..5ceddfb
--- /dev/null
+++ b/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.openwhisk.core.invoker.test
+
+import java.nio.charset.StandardCharsets
+
+import akka.actor.ActorSystem
+import akka.stream.ActorMaterializer
+import akka.testkit.{TestKit, TestProbe}
+import common.StreamLogging
+import org.apache.kafka.clients.producer.RecordMetadata
+import org.apache.openwhisk.common.{Logging, TransactionId}
+import org.apache.openwhisk.core.{WarmUp, WhiskConfig}
+import org.apache.openwhisk.core.connector.ContainerCreationError._
+import org.apache.openwhisk.core.connector._
+import org.apache.openwhisk.core.connector.test.TestConnector
+import org.apache.openwhisk.core.containerpool.v2.CreationContainer
+import org.apache.openwhisk.core.database.test.DbUtils
+import org.apache.openwhisk.core.entity.ExecManifest.{ImageName, RuntimeManifest}
+import org.apache.openwhisk.core.entity._
+import org.apache.openwhisk.core.entity.size._
+import org.apache.openwhisk.core.entity.test.ExecHelpers
+import org.apache.openwhisk.core.invoker.ContainerMessageConsumer
+import org.apache.openwhisk.http.Messages
+import org.apache.openwhisk.utils.{retry => utilRetry}
+import org.junit.runner.RunWith
+import org.scalamock.scalatest.MockFactory
+import org.scalatest.junit.JUnitRunner
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FlatSpecLike, Matchers}
+
+import scala.concurrent.Future
+import scala.concurrent.duration._
+import scala.util.Try
+
+/**
+ * Tests for [[ContainerMessageConsumer]]: creation messages are forwarded to the container pool, a failed
+ * action lookup is acked to the scheduler with `DBFetchError`, and warm-up actions never reach the pool.
+ */
+@RunWith(classOf[JUnitRunner])
+class ContainerMessageConsumerTests
+    extends TestKit(ActorSystem("ContainerMessageConsumer"))
+    with FlatSpecLike
+    with Matchers
+    with BeforeAndAfterEach
+    with BeforeAndAfterAll
+    with StreamLogging
+    with MockFactory
+    with DbUtils
+    with ExecHelpers {
+
+  // TestKit's actor system, exposed implicitly for the components under test
+  implicit val actualActorSystem = system
+  implicit val ec = actualActorSystem.dispatcher
+  implicit val materializer = ActorMaterializer()
+  implicit val transId = TransactionId.testing
+  implicit val creationId = CreationId.generate()
+
+  override def afterAll(): Unit = {
+    TestKit.shutdownActorSystem(system)
+    super.afterAll()
+  }
+
+  private val whiskConfig = new WhiskConfig(
+    Map(
+      WhiskConfig.actionInvokePerMinuteLimit -> null,
+      WhiskConfig.triggerFirePerMinuteLimit -> null,
+      WhiskConfig.actionInvokeConcurrentLimit -> null,
+      WhiskConfig.runtimesManifest -> null,
+      WhiskConfig.actionSequenceMaxLimit -> null))
+
+  private val entityStore = WhiskEntityStore.datastore()
+  private val producer = stub[MessageProducer]
+
+  private val defaultUserMemory: ByteSize = 1024.MB
+  private val invokerInstance = InvokerInstanceId(0, userMemory = defaultUserMemory)
+  private val schedulerInstanceId = SchedulerInstanceId("0")
+
+  private val invocationNamespace = EntityName("invocationSpace")
+
+  private val schedulerHost = "127.17.0.1"
+
+  private val rpcPort = 13001
+
+  // run this suite's cleanup first, then always hand over to the afterEach hooks of the stacked traits
+  override def afterEach(): Unit = {
+    try cleanup()
+    finally super.afterEach()
+  }
+
+  /** Builds a messaging provider whose consumer and producer are both backed by the given test connector. */
+  private def fakeMessageProvider(consumer: TestConnector): MessagingProvider = {
+    new MessagingProvider {
+      override def getConsumer(
+        whiskConfig: WhiskConfig,
+        groupId: String,
+        topic: String,
+        maxPeek: Int,
+        maxPollInterval: FiniteDuration)(implicit logging: Logging, actorSystem: ActorSystem): MessageConsumer =
+        consumer
+
+      override def getProducer(config: WhiskConfig, maxRequestSize: Option[ByteSize])(
+        implicit logging: Logging,
+        actorSystem: ActorSystem): MessageProducer = consumer.getProducer()
+
+      // topics need no setup in these tests, so this always succeeds
+      override def ensureTopic(config: WhiskConfig,
+                               topic: String,
+                               topicConfig: String,
+                               maxMessageBytes: Option[ByteSize])(implicit logging: Logging): Try[Unit] = Try {}
+    }
+  }
+
+  /** Sends `ackMessage` through `producer` to the `creationAck` topic of the given scheduler. */
+  def sendAckToScheduler(producer: MessageProducer)(schedulerInstanceId: SchedulerInstanceId,
+                                                    ackMessage: ContainerCreationAckMessage): Future[RecordMetadata] =
+    producer.send(s"creationAck${schedulerInstanceId.asString}", ackMessage)
+
+  /** Builds the ack this invoker is expected to send for `creationMessage` with the given error and reason. */
+  private def createAckMsg(creationMessage: ContainerCreationMessage,
+                           error: Option[ContainerCreationError],
+                           reason: Option[String]) = {
+    ContainerCreationAckMessage(
+      creationMessage.transid,
+      creationMessage.creationId,
+      creationMessage.invocationNamespace,
+      creationMessage.action,
+      creationMessage.revision,
+      creationMessage.whiskActionMetaData,
+      invokerInstance,
+      creationMessage.schedulerHost,
+      creationMessage.rpcPort,
+      creationMessage.retryCount,
+      error,
+      reason)
+  }
+
+  /** Starts a consumer that forwards to `pool`, reads through `msgProvider` and acks through `ackProducer`. */
+  private def startConsumer(pool: TestProbe, msgProvider: MessagingProvider, ackProducer: MessageProducer) =
+    new ContainerMessageConsumer(
+      invokerInstance,
+      pool.ref,
+      entityStore,
+      whiskConfig,
+      msgProvider,
+      200.milliseconds,
+      500,
+      sendAckToScheduler(ackProducer))
+
+  /** The code exec shared by all test actions. */
+  private def codeExecFixture() =
+    CodeExecAsString(RuntimeManifest("nodejs:10", ImageName("testImage")), "testCode", None)
+
+  /** Builds a test action with a one-minute time limit. */
+  private def actionFixture(namespace: EntityPath, name: EntityName, exec: CodeExecAsString) =
+    WhiskAction(namespace, name, exec, limits = ActionLimits(TimeLimit(1.minute)))
+
+  /** Builds the metadata view of `action`, taking manifest and entry point from `exec`. */
+  private def metadataFixture(action: WhiskAction, exec: CodeExecAsString) =
+    WhiskActionMetaData(
+      action.namespace,
+      action.name,
+      CodeExecMetaDataAsString(exec.manifest, entryPoint = exec.entryPoint),
+      action.parameters,
+      action.limits,
+      action.version,
+      action.publish,
+      action.annotations)
+
+  /** Builds a creation message for `action` with an empty revision; `withVersion` is passed to `fullyQualifiedName`. */
+  private def creationMessageFixture(action: WhiskAction, metadata: WhiskActionMetaData, withVersion: Boolean) =
+    ContainerCreationMessage(
+      transId,
+      invocationNamespace.asString,
+      action.fullyQualifiedName(withVersion),
+      DocRevision.empty,
+      metadata,
+      schedulerInstanceId,
+      schedulerHost,
+      rpcPort,
+      creationId = creationId)
+
+  /** Waits until exactly `expected` shows up on `ackConsumer`, then checks that `pool` received nothing. */
+  private def expectSingleAck(ackConsumer: TestConnector,
+                              ackTopic: String,
+                              expected: ContainerCreationAckMessage,
+                              pool: TestProbe): Unit =
+    within(5.seconds) {
+      utilRetry({
+        val buffer = ackConsumer.peek(50.millisecond)
+        buffer.size shouldBe 1
+        buffer.head._1 shouldBe ackTopic
+        new String(buffer.head._4, StandardCharsets.UTF_8) shouldBe expected.serialize
+      }, 10, Some(500.millisecond))
+      pool.expectNoMessage(2.seconds)
+    }
+
+  it should "forward ContainerCreationMessage to containerPool" in {
+    val pool = TestProbe()
+    val connector = new TestConnector("fakeTopic", 4, true)
+    startConsumer(pool, fakeMessageProvider(connector), producer)
+
+    val exec = codeExecFixture()
+    val action = actionFixture(EntityPath("testns"), EntityName("testAction"), exec)
+    put(entityStore, action)
+
+    connector.send(creationMessageFixture(action, metadataFixture(action, exec), withVersion = true))
+
+    pool.expectMsgPF() {
+      case CreationContainer(_, _) => true
+    }
+  }
+
+  it should "send ack(failed) to scheduler when failed to get action from DB " in {
+    val pool = TestProbe()
+    val creationConsumer = new TestConnector("creation", 4, true)
+    val msgProvider = fakeMessageProvider(creationConsumer)
+
+    val ackTopic = "ack"
+    val ackConsumer = new TestConnector(ackTopic, 4, true)
+    startConsumer(pool, msgProvider, ackConsumer.getProducer())
+
+    val exec = codeExecFixture()
+    val whiskAction = actionFixture(EntityPath("testns"), EntityName("testAction2"), exec)
+    val creationMessage =
+      creationMessageFixture(whiskAction, metadataFixture(whiskAction, exec), withVersion = true)
+
+    // action doesn't exist
+    val ackMessage = createAckMsg(creationMessage, Some(DBFetchError), Some(Messages.actionRemovedWhileInvoking))
+    creationConsumer.send(creationMessage)
+    expectSingleAck(ackConsumer, ackTopic, ackMessage, pool)
+
+    // action exists but the requested revision does not match
+    put(entityStore, whiskAction)
+    val actualCreationMessage = creationMessage.copy(revision = DocRevision("1-fake"))
+    val fetchErrorAckMessage =
+      createAckMsg(actualCreationMessage, Some(DBFetchError), Some(Messages.actionFetchErrorWhileInvoking))
+    creationConsumer.send(actualCreationMessage)
+    expectSingleAck(ackConsumer, ackTopic, fetchErrorAckMessage, pool)
+  }
+
+  it should "drop messages of warm-up action" in {
+    val pool = TestProbe()
+    val connector = new TestConnector("fakeTopic", 4, true)
+    startConsumer(pool, fakeMessageProvider(connector), producer)
+
+    val exec = codeExecFixture()
+    val action = actionFixture(WarmUp.warmUpAction.namespace.toPath, WarmUp.warmUpAction.name, exec)
+    put(entityStore, action)
+
+    connector.send(creationMessageFixture(action, metadataFixture(action, exec), withVersion = false))
+
+    pool.expectNoMessage(1.seconds)
+  }
+}