You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@nlpcraft.apache.org by se...@apache.org on 2022/09/20 10:13:02 UTC

[incubator-nlpcraft] 01/01: Fixes.

This is an automated email from the ASF dual-hosted git repository.

sergeykamov pushed a commit to branch NLPCRAFT-515
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git

commit a6609af66ddd8cc8d2e1511d1f4a9425866e08ec
Author: Sergey Kamov <sk...@gmail.com>
AuthorDate: Tue Sep 20 13:12:53 2022 +0300

    Fixes.
---
 .../main/scala/org/apache/nlpcraft/NCModel.scala   | 21 +-------
 .../scala/org/apache/nlpcraft/NCModelClient.scala  | 15 ++++++
 .../intent/matcher/NCIntentSolverManager.scala     |  2 +-
 .../internal/impl/NCModelCallbacksSpec.scala       |  1 -
 .../internal/impl/NCModelClientSpec4.scala         | 62 ++++++++++++++++++++++
 5 files changed, 79 insertions(+), 22 deletions(-)

diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCModel.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCModel.scala
index 28c108fb..091f3329 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCModel.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCModel.scala
@@ -55,26 +55,7 @@ trait NCModel:
     def getPipeline: NCPipeline
 
     /**
-      * A callback to accept or reject a parsed variant. This callback is called before any other callbacks at the
-      * beginning of the processing pipeline, and it is called for each parsed variant.
-      *
-      * Note that a given input query can have one or more possible different parsing variants. Depending on model
-      * configuration an input query can produce hundreds or even thousands of parsing variants that can significantly
-      * slow down the overall processing. This method allows to filter out unnecessary parsing variants based on
-      * variety of user-defined factors like number of entities, presence of a particular entity in the variant, etc.
-      *
-      * By default, this method accepts all variants (returns `true`).
-      *
-      * NOTE: this the pipeline has its own mechanism to filter variants via [[NCPipeline.getVariantFilter]] method and
-      * class [[NCVariantFilter]].
-      *
-      * @param vrn A parsing variant to accept or reject.
-      * @return `True` to accept variant for further processing, `false` otherwise.
-      * @see [[NCVariantFilter]]
-      */
-    def onVariant(vrn: NCVariant) = true
-
-    /**
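-      * A callback that is called when a fully assembled query context is ready. This callback is called after
-      * all {@link # onVariant ( NCVariant )} callbacks are called but before any {@link # onMatchedIntent ( NCIntentMatch )} are
+      * A callback that is called when a fully assembled query context is ready.
+      * This callback is called before any
+      * {@link # onMatchedIntent ( NCIntentMatch )} are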
       * called, i.e. right before the intent matching is performed. It's called always once per input query processing.
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCModelClient.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCModelClient.scala
index 83631668..fa0d996f 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCModelClient.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCModelClient.scala
@@ -48,6 +48,8 @@ class NCModelClient(mdl: NCModel) extends LazyLogging, AutoCloseable:
     private val plMgr = NCModelPipelineManager(mdl.getConfig, mdl.getPipeline)
     private val intentsMgr = NCIntentSolverManager(dlgMgr, convMgr, intents.map(p => p.intent -> p.function).toMap)
 
+    private var closed = false
+
     init()
 
     /**
@@ -73,6 +75,11 @@ class NCModelClient(mdl: NCModel) extends LazyLogging, AutoCloseable:
         dlgMgr.start()
         plMgr.start()
 
+    /**
+      * Guards client operations: throws `IllegalStateException` if [[close]] has already been called.
+      */
+    private def checkClosed(): Unit = if closed then throw IllegalStateException("Client is already closed.")
+
     /**
       *
       * @param txt
@@ -85,6 +92,8 @@ class NCModelClient(mdl: NCModel) extends LazyLogging, AutoCloseable:
         require(data != null, "Data cannot be null.")
         require(usrId != null, "User id cannot be null.")
 
+        checkClosed()
+
         val plData = plMgr.prepare(txt, data, usrId)
 
         val userId = plData.request.getUserId
@@ -137,6 +146,7 @@ class NCModelClient(mdl: NCModel) extends LazyLogging, AutoCloseable:
       */
     def clearStm(usrId: String): Unit =
         require(usrId != null, "User id cannot be null.")
+        checkClosed() // Fail fast once close() has been called.
         convMgr.getConversation(usrId).clear(_ => true)
 
     /**
@@ -148,6 +158,7 @@ class NCModelClient(mdl: NCModel) extends LazyLogging, AutoCloseable:
     def clearStm(usrId: String, filter: NCEntity => Boolean): Unit =
         require(usrId != null, "User id cannot be null.")
         require(filter != null, "Filter cannot be null.")
+        checkClosed() // Fail fast once close() has been called.
         convMgr.getConversation(usrId).clear(filter)
 
     /**
@@ -157,6 +168,7 @@ class NCModelClient(mdl: NCModel) extends LazyLogging, AutoCloseable:
       */
     def clearDialog(usrId: String): Unit =
         require(usrId != null, "User id cannot be null.")
+        checkClosed() // Fail fast once close() has been called.
         dlgMgr.clear(usrId)
 
     /**
@@ -168,12 +180,15 @@ class NCModelClient(mdl: NCModel) extends LazyLogging, AutoCloseable:
     def clearDialog(usrId: String, filter: NCDialogFlowItem => Boolean): Unit =
         require(usrId != null, "User ID cannot be null.")
-        require(usrId != null, "Filter cannot be null.")
+        require(filter != null, "Filter cannot be null.") // Validate the filter argument itself.
+        checkClosed()
         dlgMgr.clear(usrId, (i: NCDialogFlowItem) => filter(i))
 
     /**
       * Closes this client releasing its associated resources.
       */
     override def close(): Unit =
+        checkClosed()
+        closed = true
         plMgr.close()
         dlgMgr.close()
         convMgr.close()
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/matcher/NCIntentSolverManager.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/matcher/NCIntentSolverManager.scala
index b41a302f..8fe00ee2 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/matcher/NCIntentSolverManager.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/matcher/NCIntentSolverManager.scala
@@ -272,7 +272,7 @@ class NCIntentSolverManager(
 
         // Find all matches across all intents and sentence variants.
         for (
-            (vrn, vrnIdx) <- ctx.getVariants.zipWithIndex if mdl.onVariant(vrn);
+            (vrn, vrnIdx) <- ctx.getVariants.zipWithIndex;
             ents = vrn.getEntities;
             varEntsGroups = ents.filter(t => t.getGroups != null && t.getGroups.nonEmpty).map(_.getGroups);
             (intent, callback) <- intents
diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/impl/NCModelCallbacksSpec.scala b/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/impl/NCModelCallbacksSpec.scala
index 938fa56a..d7385e98 100644
--- a/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/impl/NCModelCallbacksSpec.scala
+++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/impl/NCModelCallbacksSpec.scala
@@ -55,7 +55,6 @@ class NCModelCallbacksSpec extends AnyFunSuite:
                 else RESULT_INTENT
 
             override def onMatchedIntent(ctx: NCContext, im: NCIntentMatch): Boolean = getOrElse(MatchFalse, false, true)
-            override def onVariant(vrn: NCVariant): Boolean = getOrElse(VariantFalse, false, true)
             override def onContext(ctx: NCContext): Option[NCResult] = getOrElse(ContextNotNull, Some(RESULT_CONTEXT), None)
             override def onResult(ctx: NCContext, im: NCIntentMatch, res: NCResult): Option[NCResult] = getOrElse(ResultNotNull, Some(RESULT_RESULT), None)
             override def onRejection(ctx: NCContext, im: Option[NCIntentMatch], e: NCRejection): Option[NCResult] = getOrElse(RejectionNotNull, Some(RESULT_REJECTION), None)
diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/impl/NCModelClientSpec4.scala b/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/impl/NCModelClientSpec4.scala
new file mode 100644
index 00000000..43b2430c
--- /dev/null
+++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/impl/NCModelClientSpec4.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nlpcraft.internal.impl
+
+import org.apache.nlpcraft.*
+import org.apache.nlpcraft.annotations.NCIntent
+import org.apache.nlpcraft.nlp.parsers.NCNLPEntityParser
+import org.apache.nlpcraft.nlp.util.*
+import org.scalatest.funsuite.AnyFunSuite
+
+/**
+  * Verifies that the [[NCModelClient]] operations exercised below succeed while the client is open
+  * and throw [[IllegalStateException]] once it has been closed, including a repeated `close()`.
+  */
+class NCModelClientSpec4 extends AnyFunSuite:
+    test("client operations throw IllegalStateException after close") {
+        val pl = mkEnPipeline
+
+        //  For intents matching, we have to add at least one entity parser.
+        pl.entParsers += new NCNLPEntityParser
+
+        val mdl: NCModel = new NCModelAdapter(CFG, pl) :
+            @NCIntent("intent=i term(any)={true}")
+            def onMatch(ctx: NCContext, im: NCIntentMatch): NCResult = TEST_RESULT
+
+        val client = new NCModelClient(mdl)
+
+        val allCalls = Seq(
+            () => client.ask("test", "userId"),
+            () => client.debugAsk("test", "userId", false),
+            () => client.clearStm("userId", _ => true),
+            () => client.clearStm("userId"),
+            () => client.clearDialog("userId"),
+            () => client.clearDialog("userId", _ => true)
+        )
+
+        // Every operation must succeed while the client is still open.
+        allCalls.foreach(call => call())
+
+        client.close()
+
+        // After closing, every operation - including a second close() - must be rejected.
+        for (call <- allCalls :+ (() => client.close()))
+            assertThrows[IllegalStateException] { call() }
+    }
+
+