You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nlpcraft.apache.org by ar...@apache.org on 2022/02/01 01:53:57 UTC
[incubator-nlpcraft] branch NLPCRAFT-477 updated: WIP
This is an automated email from the ASF dual-hosted git repository.
aradzinski pushed a commit to branch NLPCRAFT-477
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git
The following commit(s) were added to refs/heads/NLPCRAFT-477 by this push:
new dcd0754 WIP
dcd0754 is described below
commit dcd0754a532fef1dc50a9c8836136d1bb8aaa719
Author: Aaron Radzinski <ar...@datalingvo.com>
AuthorDate: Mon Jan 31 17:53:49 2022 -0800
WIP
---
.../main/scala/org/apache/nlpcraft/NCEntity.java | 10 +
.../internal/conversation/NCConversation.scala | 212 +++++++++++++++++++++
.../nlpcraft/internal/intent/NCIDLEntity.scala | 2 +-
.../apache/nlpcraft/internal/util/NCUtils.scala | 18 +-
4 files changed, 239 insertions(+), 3 deletions(-)
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCEntity.java b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCEntity.java
index 5069b03..c24b638 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCEntity.java
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCEntity.java
@@ -18,6 +18,7 @@
package org.apache.nlpcraft;
import java.util.*;
+import java.util.stream.Collectors;
/**
*
@@ -30,6 +31,15 @@ public interface NCEntity extends NCPropertyMap {
List<NCToken> getTokens();
/**
+     * Builds the text of this entity by joining the text of all its tokens with a single space. Each
+     * token's text is trimmed before joining, and the joined result is trimmed as well. The text is
+     * rebuilt on every call and is never cached, so callers that need it repeatedly should cache it
+     * for as long as {@link #getTokens()} does not change.
+     *
+     * @return Space-delimited, trimmed text of this entity's tokens (empty string if there are no tokens).
+     */
+    default String mkText() {
+        StringJoiner joiner = new StringJoiner(" ");
+
+        for (NCToken tok : getTokens())
+            joiner.add(tok.getText().trim());
+
+        return joiner.toString().trim();
+    }
+
+ /**
* Gets ID of the request this entity is part of.
*
* @return ID of the request this entity is part of.
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversation.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversation.scala
new file mode 100644
index 0000000..e694bee
--- /dev/null
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversation.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nlpcraft.internal.conversation
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+import java.util.function.Predicate
+
+import com.typesafe.scalalogging.LazyLogging
+import org.apache.nlpcraft.*
+import org.apache.nlpcraft.internal.ascii.*
+import org.apache.nlpcraft.internal.util.*
+
+import scala.collection.mutable
+import scala.jdk.CollectionConverters.*
+
+/**
+ * An active conversation is an ordered set of utterances for the specific user and data model.
+ *
+ * The conversation keeps a short-term memory (STM). The STM is a list of conversation items, one per request
+ * that contributed entities, ordered from the oldest to the newest. Individual entities leave the STM when:
+ *  - they were not referenced by `addEntities` within the last `timeoutMs`;
+ *  - they are not among the entities passed to the last `maxDepth` calls of `addEntities`;
+ *  - a newer item overrides their group list (see `addEntities`).
+ * The whole STM is reset when `updateEntities` finds that more than `timeoutMs` has passed since its previous
+ * call. It is also reset when `updateEntities` is called more than `maxDepth` times in a row without an
+ * intervening `addEntities` call.
+ *
+ * STM state is only mutated while synchronized on the `stm` buffer.
+ *
+ * @param usrId User ID.
+ * @param mdlId Model ID.
+ * @param timeoutMs Timeout in milliseconds for the whole STM and for individual STM entities.
+ * @param maxDepth Maximum number of consecutive `updateEntities` calls without `addEntities` before the STM
+ *      is reset; also the number of most recent `addEntities` batches whose entities are kept.
+ */
+case class NCConversation(
+    usrId: Long,
+    mdlId: String,
+    timeoutMs: Long,
+    maxDepth: Int
+) extends LazyLogging {
+    // Arbitrary user data attached to this conversation. The map is thread-safe on its own and is not guarded by `stm`.
+    private final val data = new ConcurrentHashMap[String, Object]()
+
+    /**
+     * STM entity together with the last time (in ms) it was passed to `addEntities`.
+     */
+    case class EntityHolder(entity: NCEntity, var entityTypeUsageTime: Long = 0)
+
+    /**
+     * STM element: holders of the entities contributed by request `reqId`, stamped with the conversation's
+     * last update time at the moment the item was added.
+     */
+    case class ConversationItem(holders: mutable.ArrayBuffer[EntityHolder], reqId: String, tstamp: Long)
+
+    // Short-Term-Memory, oldest item first. Its monitor guards the mutable state of this conversation.
+    private val stm = mutable.ArrayBuffer.empty[ConversationItem]
+    // Entity batches passed to the last `maxDepth` calls of `addEntities`.
+    private val lastEnts = mutable.ArrayBuffer.empty[Iterable[NCEntity]]
+
+    @volatile private var ctx: util.List[NCEntity] = new util.ArrayList[NCEntity]()
+    @volatile private var lastUpdateTstamp = NCUtils.nowUtcMs()
+    @volatile private var depth = 0
+
+    /**
+     * Gets the groups of the given entity, treating `null` groups as "no groups".
+     *
+     * @param ent Entity.
+     * @return Entity groups, or an empty collection if the entity reports `null`.
+     */
+    private def groupsOf(ent: NCEntity): Iterable[String] =
+        val grps = ent.getGroups
+        if grps == null then Seq.empty else grps.asScala
+
+    /**
+     * Removes STM items that have no entities left. Must be called while holding the `stm` monitor.
+     */
+    private def squeezeEntities(): Unit =
+        require(Thread.holdsLock(stm))
+        stm.filterInPlace(_.holders.nonEmpty)
+
+    /**
+     * Gets called on each input request for given user and model.
+     *
+     * Increments the depth counter and then takes one of three paths:
+     *  - resets the STM if more than `timeoutMs` passed since the previous call;
+     *  - resets the STM if the depth counter exceeds `maxDepth`;
+     *  - otherwise evicts the entities that were not used within `timeoutMs` or are not among the last
+     *    `maxDepth` batches passed to `addEntities`.
+     * Finally, it records the update time, rebuilds the context snapshot and logs it.
+     */
+    def updateEntities(): Unit =
+        val now = NCUtils.nowUtcMs()
+
+        stm.synchronized {
+            depth += 1
+
+            lazy val z = s"usrId=$usrId, mdlId=$mdlId"
+
+            // Conversation cleared by timeout or when there are too many unsuccessful requests.
+            if now - lastUpdateTstamp > timeoutMs then
+                stm.clear()
+                logger.trace(s"STM is reset by timeout [$z]")
+            else if depth > maxDepth then
+                stm.clear()
+                logger.trace(s"STM is reset after reaching max depth [$z]")
+            else
+                val minUsageTime = now - timeoutMs
+                val ents = lastEnts.flatten
+
+                for (item <- stm)
+                    // Deleted by timeout for entity type or when an entity type used too many requests ago.
+                    val delHs = item.holders.filter(h => h.entityTypeUsageTime < minUsageTime || !ents.contains(h.entity))
+
+                    if delHs.nonEmpty then
+                        item.holders --= delHs
+                        logger.trace(s"STM entity removed [$z, reqId=${item.reqId}]")
+                        stepLogEntity(delHs.toSeq.map(_.entity))
+
+                squeezeEntities()
+
+            lastUpdateTstamp = now
+            ctx = new util.ArrayList[NCEntity](stm.flatMap(_.holders.map(_.entity)).asJava)
+            ack()
+        }
+
+    /**
+     * Clears all entities from this conversation satisfying given predicate.
+     *
+     * @param p Java-side predicate.
+     */
+    def clearEntities(p: Predicate[NCEntity]): Unit =
+        stm.synchronized {
+            for (item <- stm) item.holders.filterInPlace(h => !p.test(h.entity))
+            squeezeEntities()
+            // Snapshot into a fresh list, same as `updateEntities` does.
+            ctx = new util.ArrayList[NCEntity](ctx.asScala.filter(ent => !p.test(ent)).asJava)
+        }
+
+        logger.trace(s"STM is cleared [usrId=$usrId, mdlId=$mdlId]")
+
+    /**
+     * Clears all entities from this conversation satisfying given predicate.
+     *
+     * @param p Scala-side predicate.
+     */
+    def clearEntities(p: NCEntity => Boolean): Unit =
+        clearEntities(new Predicate[NCEntity] {
+            override def test(t: NCEntity): Boolean = p(t)
+        })
+
+    /**
+     * Logs given entities at trace level, one per line.
+     *
+     * @param ents Entities to log.
+     */
+    private def stepLogEntity(ents: Seq[NCEntity]): Unit =
+        for (ent <- ents) logger.trace(s" +-- $ent")
+
+    /**
+     * Adds given entities to the conversation.
+     *
+     * Resets the depth counter and remembers `ents` as one of the last `maxDepth` entity batches. If some of
+     * `ents` belong to request `reqId`, the method does the following:
+     *  - those entities form a new STM item;
+     *  - the STM is walked from the newest item to the oldest, and a group of entities is dropped when its
+     *    sorted group list contains, as a contiguous slice, a group list registered earlier in the walk;
+     *  - the usage time of every STM entity present in `ents` is refreshed.
+     *
+     * @param reqId Server request ID.
+     * @param ents Entities to add to the conversation STM.
+     */
+    def addEntities(reqId: String, ents: Seq[NCEntity]): Unit =
+        stm.synchronized {
+            depth = 0
+            lastEnts += ents // Last used entities processing.
+
+            // Keep only the batches of the last `maxDepth` calls.
+            val delCnt = lastEnts.length - maxDepth
+            if delCnt > 0 then lastEnts.remove(0, delCnt)
+
+            val senEnts = ents.filter(_.getRequestId == reqId)
+            if senEnts.nonEmpty then
+                // Adds new conversation element.
+                stm += ConversationItem(
+                    mutable.ArrayBuffer.empty[EntityHolder] ++ senEnts.map(EntityHolder(_)),
+                    reqId,
+                    lastUpdateTstamp
+                )
+
+                logger.trace(s"Added new entities to STM [usrId=$usrId, mdlId=$mdlId, reqId=$reqId]")
+                stepLogEntity(ents)
+
+                val registered = mutable.HashSet.empty[Seq[String]]
+                for (item <- stm.reverse; (gs, hs) <- item.holders.groupBy(h => groupsOf(h.entity)))
+                    val grps = gs.toSeq.sorted
+
+                    // Reversed iteration.
+                    // N : (A, B) -> registered.
+                    // N-1 : (C) -> registered.
+                    // N-2 : (A, B) or (A, B, X) etc -> deleted, because registered has less groups.
+                    registered.find(grps.containsSlice) match
+                        case Some(_) =>
+                            item.holders --= hs
+                            for (ent <- hs.map(_.entity)) logger.trace(s"STM entity overridden: $ent")
+
+                        case None => registered += grps
+
+                // Updates entity usage time.
+                stm.foreach(_.holders.filter(h => ents.contains(h.entity)).foreach(_.entityTypeUsageTime = lastUpdateTstamp))
+
+                squeezeEntities()
+        }
+
+    /**
+     * Logs the current STM context as an ASCII table, or a trace message if it is empty.
+     * Must be called while holding the `stm` monitor.
+     */
+    private def ack(): Unit =
+        require(Thread.holdsLock(stm))
+
+        val z = s"mdlId=$mdlId, usrId=$usrId"
+
+        if ctx.isEmpty then logger.trace(s"STM is empty for [$z]")
+        else
+            val tbl = NCAsciiTable("Entity ID", "Groups", "Request ID")
+            ctx.asScala.foreach(ent => tbl += (
+                ent.getId,
+                // Null-safe: `addEntities` accepts entities with `null` groups.
+                groupsOf(ent).mkString(", "),
+                ent.getRequestId
+            ))
+            logger.info(s"Current STM for [$z]:\n${tbl.toString()}")
+
+    /**
+     * Gets a copy of the current STM context entities.
+     *
+     * Entities are grouped by request ID. Groups are ordered from the most recently added request to the
+     * oldest one, and the entity order within each group is preserved.
+     *
+     * @return New mutable list of context entities.
+     */
+    def getEntity: util.List[NCEntity] =
+        stm.synchronized {
+            val reqIds = ctx.asScala.map(_.getRequestId).distinct.zipWithIndex.toMap
+            val ents = ctx.asScala.groupBy(_.getRequestId).toSeq.sortBy(p => reqIds(p._1)).reverse.flatMap(_._2)
+
+            new util.ArrayList[NCEntity](ents.asJava)
+        }
+
+    /**
+     * Gets the user data attached to this conversation.
+     *
+     * @return Live, thread-safe (concurrent) map of user data.
+     */
+    def getUserData: util.Map[String, Object] = data
+}
\ No newline at end of file
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/NCIDLEntity.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/NCIDLEntity.scala
index f10c577..76f3c78 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/NCIDLEntity.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/NCIDLEntity.scala
@@ -26,7 +26,7 @@ import scala.jdk.CollectionConverters.*
* @param idx
*/
class NCIDLEntity(ent: NCEntity, idx: Int):
- private lazy val txt = ent.getTokens.asScala.map(_.getText).mkString(" ")
+ private lazy val txt = ent.mkText()
def getImpl: NCEntity = ent
def getText: String = txt
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
index 73b898a..478e9f3 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
@@ -20,12 +20,15 @@ package org.apache.nlpcraft.internal.util
import com.typesafe.scalalogging.*
import org.apache.nlpcraft.*
import com.google.gson.*
+
import java.io.*
import java.net.*
-import java.util.concurrent.{CopyOnWriteArrayList, ExecutorService, TimeUnit} // Avoids conflicts.
+import java.time.{ZoneId, Instant, ZonedDateTime}
+import java.util.concurrent.{CopyOnWriteArrayList, ExecutorService, TimeUnit}
import java.util.regex.Pattern
import java.util.zip.*
-import java.util.{Random, UUID}
+import java.util.{Random, TimeZone}
+
import scala.annotation.tailrec
import scala.collection.{IndexedSeq, Seq, mutable}
import scala.concurrent.*
@@ -41,6 +44,7 @@ import scala.util.Using
object NCUtils extends LazyLogging:
final val NL = System getProperty "line.separator"
private val RND = new Random()
+    // UTC zone used by `nowUtc()`.
+ private final val UTC = ZoneId.of("UTC")
private val sysProps = new SystemProperties
private final lazy val GSON = new GsonBuilder().setPrettyPrinting().disableHtmlEscaping().create()
@@ -199,6 +203,16 @@ object NCUtils extends LazyLogging:
catch case e: Exception => E(s"Cannot extract JSON field '$field' from: '$json'", e)
/**
+     * Gets the current date and time in the UTC time zone.
+     *
+     * @return Current moment as a zoned date-time in UTC.
+     */
+    def nowUtc(): ZonedDateTime = Instant.now().atZone(UTC)
+
+    /**
+     * Gets the current moment as milliseconds since the Unix epoch (1970-01-01T00:00:00Z).
+     * Epoch milliseconds do not depend on any time zone, so the value is UTC-based by definition.
+     *
+     * @return Milliseconds since the Unix epoch.
+     */
+    def nowUtcMs(): Long =
+        val instant = Instant.now()
+        instant.toEpochMilli
+
+ /**
* Shortcut - current timestamp in milliseconds.
+     *
+     * @return Value of `System.currentTimeMillis()`.
*/
def now(): Long = System.currentTimeMillis()