You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nlpcraft.apache.org by ar...@apache.org on 2022/02/01 01:53:57 UTC

[incubator-nlpcraft] branch NLPCRAFT-477 updated: WIP

This is an automated email from the ASF dual-hosted git repository.

aradzinski pushed a commit to branch NLPCRAFT-477
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git


The following commit(s) were added to refs/heads/NLPCRAFT-477 by this push:
     new dcd0754  WIP
dcd0754 is described below

commit dcd0754a532fef1dc50a9c8836136d1bb8aaa719
Author: Aaron Radzinski <ar...@datalingvo.com>
AuthorDate: Mon Jan 31 17:53:49 2022 -0800

    WIP
---
 .../main/scala/org/apache/nlpcraft/NCEntity.java   |  10 +
 .../internal/conversation/NCConversation.scala     | 212 +++++++++++++++++++++
 .../nlpcraft/internal/intent/NCIDLEntity.scala     |   2 +-
 .../apache/nlpcraft/internal/util/NCUtils.scala    |  18 +-
 4 files changed, 239 insertions(+), 3 deletions(-)

diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCEntity.java b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCEntity.java
index 5069b03..c24b638 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCEntity.java
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCEntity.java
@@ -18,6 +18,7 @@
 package org.apache.nlpcraft;
 
 import java.util.*;
+import java.util.stream.Collectors;
 
 /**
  *
@@ -30,6 +31,15 @@ public interface NCEntity extends NCPropertyMap {
     List<NCToken> getTokens();
 
     /**
+     * Joins the texts of all tokens of this entity using a single space as a delimiter. Each token's text is
+     * trimmed first and tokens whose trimmed text is empty are skipped, so blank tokens do not produce
+     * leading, trailing or doubled delimiter spaces.
+     * <p>
+     * This function does not cache the result and performs text construction on each call. Make sure to cache
+     * the result to avoid unnecessary parasitic workload if and when method {@link #getTokens()} does not change.
+     *
+     * @return Space-joined trimmed texts of this entity's tokens (empty string if there is no non-blank token).
+     */
+    default String mkText() {
+        return getTokens().stream()
+            .map(tok -> tok.getText().trim())
+            // Skipping empty pieces keeps the result free of stray delimiters, which also makes an outer trim() redundant.
+            .filter(txt -> !txt.isEmpty())
+            .collect(Collectors.joining(" "));
+    }
+
+    /**
      * Gets ID of the request this entity is part of.
      *
      * @return ID of the request this entity is part of.
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversation.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversation.scala
new file mode 100644
index 0000000..e694bee
--- /dev/null
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversation.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nlpcraft.internal.conversation
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+import java.util.function.Predicate
+
+import com.typesafe.scalalogging.LazyLogging
+import org.apache.nlpcraft.*
+import org.apache.nlpcraft.internal.ascii.*
+import org.apache.nlpcraft.internal.util.*
+
+import scala.collection.mutable
+import scala.jdk.CollectionConverters.*
+
+/**
+  * An active conversation is an ordered set of utterances for the specific user and data model.
+  *
+  * The conversation keeps a short-term memory (STM) of entities from recent requests. STM state is read and
+  * modified while holding the `stm` monitor. `ctx` holds a snapshot of the STM entities taken by
+  * `updateEntities` (and filtered by `clearEntities`); it is always replaced with a new list, never modified
+  * in place.
+  *
+  * @param usrId User ID.
+  * @param mdlId Model ID.
+  * @param timeoutMs Timeout in milliseconds: the STM is reset if no update happened for longer than this,
+  *     and individual entities are evicted if their usage time is older than this.
+  * @param maxDepth Maximum number of consecutive `updateEntities` calls without an intervening `addEntities`
+  *     call before the STM is reset. It also limits how many recent entity batches are remembered.
+  */
+case class NCConversation(
+    usrId: Long,
+    mdlId: String,
+    timeoutMs: Long,
+    maxDepth: Int
+) extends LazyLogging {
+    // Arbitrary user data attached to this conversation, exposed via `getUserData`.
+    private final val data = new ConcurrentHashMap[String, Object]()
+
+    case class EntityHolder(entity: NCEntity, var entityTypeUsageTime: Long = 0)
+    case class ConversationItem(holders: mutable.ArrayBuffer[EntityHolder], reqId: String, tstamp: Long)
+
+    // Short-Term-Memory, oldest item first.
+    private val stm = mutable.ArrayBuffer.empty[ConversationItem]
+    // Entity batches passed to the last (at most `maxDepth`) `addEntities` calls, oldest first.
+    private val lastEnts = mutable.ArrayBuffer.empty[Iterable[NCEntity]]
+
+    @volatile private var ctx: util.List[NCEntity] = new util.ArrayList[NCEntity]()
+    @volatile private var lastUpdateTstamp = NCUtils.nowUtcMs()
+    @volatile private var depth = 0
+
+    /**
+      * Removes conversation items that no longer hold any entities.
+      * Must be called while holding the `stm` monitor.
+      */
+    private def squeezeEntities(): Unit =
+        require(Thread.holdsLock(stm))
+        // In-place single pass instead of filter-then-subtract (which is quadratic).
+        stm.filterInPlace(_.holders.nonEmpty)
+
+    /**
+      * Gets called on each input request for given user and model.
+      *
+      * Resets the whole STM if the conversation has been idle for longer than `timeoutMs`, or if this method
+      * has been called more than `maxDepth` times in a row without `addEntities` being called. Otherwise, it
+      * evicts entities whose usage time is older than `timeoutMs` or that are not among the recently added
+      * entity batches. Finally, it takes a new entity snapshot and prints the STM.
+      */
+    def updateEntities(): Unit =
+        val now = NCUtils.nowUtcMs()
+
+        stm.synchronized {
+            depth += 1
+
+            lazy val z = s"usrId=$usrId, mdlId=$mdlId"
+
+            // Conversation cleared by timeout or when there are too many unsuccessful requests.
+            if now - lastUpdateTstamp > timeoutMs then
+                stm.clear()
+                logger.trace(s"STM is reset by timeout [$z]")
+            else if depth > maxDepth then
+                stm.clear()
+                logger.trace(s"STM is reset after reaching max depth [$z]")
+            else
+                val minUsageTime = now - timeoutMs
+                val ents = lastEnts.flatten
+
+                for (item <- stm)
+                    // Deleted by timeout for entity type or when an entity type was last used too many requests ago.
+                    val delHs = item.holders.filter(h => h.entityTypeUsageTime < minUsageTime || !ents.contains(h.entity))
+
+                    if delHs.nonEmpty then
+                        item.holders --= delHs
+                        logger.trace(s"STM entity removed [$z, reqId=${item.reqId}]")
+                        stepLogEntity(delHs.toSeq.map(_.entity))
+
+                squeezeEntities()
+
+            lastUpdateTstamp = now
+            ctx = new util.ArrayList[NCEntity](stm.flatMap(_.holders.map(_.entity)).asJava)
+            ack()
+        }
+
+    /**
+      * Clears all entities from this conversation satisfying given predicate.
+      *
+      * @param p Java-side predicate.
+      */
+    def clearEntities(p: Predicate[NCEntity]): Unit =
+        stm.synchronized {
+            for (item <- stm) item.holders.filterInPlace(h => !p.test(h.entity))
+            squeezeEntities()
+            // Install a fresh list (as `updateEntities` does) rather than a view over a temporary Scala buffer.
+            ctx = new util.ArrayList[NCEntity](ctx.asScala.filter(ent => !p.test(ent)).asJava)
+        }
+
+        logger.trace(s"STM is cleared [usrId=$usrId, mdlId=$mdlId]")
+
+    /**
+      * Clears all entities from this conversation satisfying given predicate.
+      *
+      * @param p Scala-side predicate.
+      */
+    def clearEntities(p: NCEntity => Boolean): Unit =
+        clearEntities(new Predicate[NCEntity]:
+            override def test(t: NCEntity): Boolean = p(t)
+        )
+
+    /**
+      * Logs each given entity on its own line at trace level.
+      *
+      * @param ents Entities to log.
+      */
+    private def stepLogEntity(ents: Seq[NCEntity]): Unit =
+        for (ent <- ents) logger.trace(s"  +-- $ent")
+
+    /**
+      * Adds given entities to the conversation.
+      *
+      * Only entities whose request ID equals `reqId` become a new STM item. All given entities are remembered
+      * as the latest batch and refresh the usage time of matching entities already in STM. Older entities
+      * whose group set is overridden by a newer one are removed.
+      *
+      * @param reqId Server request ID.
+      * @param ents Entities to add to the conversation STM.
+      */
+    def addEntities(reqId: String, ents: Seq[NCEntity]): Unit =
+        stm.synchronized {
+            depth = 0
+            lastEnts += ents // Last used entities processing.
+
+            val delCnt = lastEnts.length - maxDepth
+            if delCnt > 0 then lastEnts.remove(0, delCnt)
+
+            val senEnts = ents.filter(_.getRequestId == reqId)
+            if senEnts.nonEmpty then
+                // Adds new conversation element.
+                stm += ConversationItem(
+                    mutable.ArrayBuffer.empty[EntityHolder] ++ senEnts.map(EntityHolder(_)),
+                    reqId,
+                    lastUpdateTstamp
+                )
+
+                logger.trace(s"Added new entities to STM [usrId=$usrId, mdlId=$mdlId, reqId=$reqId]")
+                // Log exactly the entities that were added to STM, not every given entity.
+                stepLogEntity(senEnts)
+
+                val registered = mutable.HashSet.empty[Seq[String]]
+                for (item <- stm.reverse; (gs, hs) <- item.holders.groupBy(t => if (t.entity.getGroups != null) t.entity.getGroups.asScala else Seq.empty))
+                    val grps = gs.toSeq.sorted
+
+                    // Reversed iteration.
+                    // N : (A, B) -> registered.
+                    // N-1 : (C) -> registered.
+                    // N-2 : (A, B) or (A, B, X) etc -> deleted, because registered has fewer groups.
+                    // NOTE(review): entities without groups register an empty sequence, and every sequence
+                    // contains the empty slice, so all older entities are then treated as overridden - confirm intended.
+                    registered.find(grps.containsSlice) match
+                        case Some(_) =>
+                            item.holders --= hs
+                            for (ent <- hs.map(_.entity)) logger.trace(s"STM entity overridden: $ent")
+
+                        case None => registered += grps
+
+                // Updates entity usage time.
+                stm.foreach(_.holders.filter(h => ents.contains(h.entity)).foreach(_.entityTypeUsageTime = lastUpdateTstamp))
+
+                squeezeEntities()
+        }
+
+    /**
+      * Prints out ASCII table for current STM snapshot (`ctx`).
+      * Must be called while holding the `stm` monitor.
+      */
+    private def ack(): Unit =
+        require(Thread.holdsLock(stm))
+
+        val z = s"mdlId=$mdlId, usrId=$usrId"
+
+        if ctx.isEmpty then logger.trace(s"STM is empty for [$z]")
+        else
+            val tbl = NCAsciiTable("Entity ID", "Groups", "Request ID")
+            ctx.asScala.foreach(ent => tbl += (
+                ent.getId,
+                // Groups may be null (as `addEntities` already accounts for); print them as empty.
+                if ent.getGroups != null then ent.getGroups.asScala.mkString(", ") else "",
+                ent.getRequestId
+            ))
+            logger.info(s"Current STM for [$z]:\n${tbl.toString()}")
+
+    /**
+      * Gets the entities of the current STM snapshot grouped by request ID. The group of the request added
+      * to STM most recently comes first; entity order within a group is preserved.
+      *
+      * @return New mutable list of STM entities.
+      */
+    def getEntity: util.List[NCEntity] =
+        stm.synchronized {
+            val reqIds = ctx.asScala.map(_.getRequestId).distinct.zipWithIndex.toMap
+            val ents = ctx.asScala.groupBy(_.getRequestId).toSeq.sortBy(p => reqIds(p._1)).reverse.flatMap(_._2)
+
+            new util.ArrayList[NCEntity](ents.asJava)
+        }
+
+    /**
+      * Gets user data attached to this conversation.
+      *
+      * @return Live, concurrent map (changes are visible to all holders of this conversation).
+      */
+    def getUserData: util.Map[String, Object] = data
+}
\ No newline at end of file
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/NCIDLEntity.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/NCIDLEntity.scala
index f10c577..76f3c78 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/NCIDLEntity.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/NCIDLEntity.scala
@@ -26,7 +26,7 @@ import scala.jdk.CollectionConverters.*
   * @param idx
   */
 class NCIDLEntity(ent: NCEntity, idx: Int):
-    private lazy val txt = ent.getTokens.asScala.map(_.getText).mkString(" ")
+    private lazy val txt = ent.mkText()
 
     def getImpl: NCEntity = ent
     def getText: String = txt
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
index 73b898a..478e9f3 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
@@ -20,12 +20,15 @@ package org.apache.nlpcraft.internal.util
 import com.typesafe.scalalogging.*
 import org.apache.nlpcraft.*
 import com.google.gson.*
+
 import java.io.*
 import java.net.*
-import java.util.concurrent.{CopyOnWriteArrayList, ExecutorService, TimeUnit} // Avoids conflicts.
+import java.time.{ZoneId, Instant, ZonedDateTime}
+import java.util.concurrent.{CopyOnWriteArrayList, ExecutorService, TimeUnit}
 import java.util.regex.Pattern
 import java.util.zip.*
-import java.util.{Random, UUID}
+import java.util.{Random, TimeZone}
+
 import scala.annotation.tailrec
 import scala.collection.{IndexedSeq, Seq, mutable}
 import scala.concurrent.*
@@ -41,6 +44,7 @@ import scala.util.Using
 object NCUtils extends LazyLogging:
     final val NL = System getProperty "line.separator"
     private val RND = new Random()
+    private final val UTC = ZoneId.of("UTC")
     private val sysProps = new SystemProperties
     private final lazy val GSON = new GsonBuilder().setPrettyPrinting().disableHtmlEscaping().create()
 
@@ -199,6 +203,16 @@ object NCUtils extends LazyLogging:
         catch case e: Exception => E(s"Cannot extract JSON field '$field' from: '$json'", e)
 
     /**
+      * Gets the current date-time in the UTC time zone.
+      *
+      * @return Current date-time with zone `UTC`.
+      */
+    def nowUtc(): ZonedDateTime = ZonedDateTime.now(UTC)
+
+    /**
+      * Gets the current instant as milliseconds since the Unix epoch (1970-01-01T00:00:00Z).
+      * Epoch milliseconds do not depend on any time zone, so this is equivalent to `now()`; it is just
+      * obtained through `java.time.Instant` instead of `System.currentTimeMillis()`.
+      *
+      * @return Current time in epoch milliseconds.
+      */
+    def nowUtcMs(): Long = Instant.now().toEpochMilli
+
+    /**
       * Shortcut - current timestamp in milliseconds.
       */
     def now(): Long = System.currentTimeMillis()