You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nlpcraft.apache.org by se...@apache.org on 2020/06/12 08:55:13 UTC

[incubator-nlpcraft] 02/02: WIP.

This is an automated email from the ASF dual-hosted git repository.

sergeykamov pushed a commit to branch NLPCRAFT-41
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git

commit 7479b7f0076b396b2f4ecc1f72f1ecf7c1d0d0e6
Author: Sergey Kamov <se...@apache.org>
AuthorDate: Fri Jun 12 11:55:01 2020 +0300

    WIP.
---
 .../tools/suggestions/NCSuggestionsGenerator.scala | 52 +++++++++++-----------
 1 file changed, 25 insertions(+), 27 deletions(-)

diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/model/tools/suggestions/NCSuggestionsGenerator.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/model/tools/suggestions/NCSuggestionsGenerator.scala
index a1fddc9..17d70de 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/model/tools/suggestions/NCSuggestionsGenerator.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/model/tools/suggestions/NCSuggestionsGenerator.scala
@@ -41,10 +41,10 @@ import scala.collection._
 case class ParametersHolder(modelPath: String, url: String, limit: Int, minScore: Double, debug: Boolean)
 
 object NCSuggestionsGeneratorImpl {
-    case class Suggestion(word: String, index1: Double, index2: Double, index3: Double)
+    case class Suggestion(word: String, totalScore: Double, fastTextScore: Double, bertScore: Double)
 
     case class RequestData(sentence: String, example: String, elementId: String, index: Int)
-    case class RestRequest(sentences: java.util.List[java.util.List[Any]], limit: Int, simple: Boolean = false)
+    case class RestRequest(sentences: java.util.List[java.util.List[Any]], limit: Int, min_score: Double, simple: Boolean = false)
     case class Word(word: String, stem: String) {
         require(!word.contains(" "), s"Word cannot contains spaces: $word")
         require(
@@ -82,11 +82,11 @@ object NCSuggestionsGeneratorImpl {
                         else
                             p.asScala.tail.map(p ⇒
                                 Suggestion(
-                                    word = p.get(2).asInstanceOf[String],
-                                    index1 = p.get(0).asInstanceOf[Double],
-                                    index2 = p.get(1).asInstanceOf[Double],
-                                    index3 = p.get(3).asInstanceOf[Double]
-                                ),
+                                    word = p.get(0).asInstanceOf[String],
+                                    totalScore = p.get(1).asInstanceOf[Double],
+                                    fastTextScore = p.get(2).asInstanceOf[Double],
+                                    bertScore = p.get(3).asInstanceOf[Double]
+                                )
                             )
                     )
 
@@ -198,7 +198,9 @@ object NCSuggestionsGeneratorImpl {
                             GSON.toJson(
                                 RestRequest(
                                     sentences = batch.map(p ⇒ Seq(p.sentence, p.index).asJava).asJava,
-                                    limit = data.limit)
+                                    min_score = data.minScore,
+                                    limit = data.limit
+                                )
                             ),
                             "UTF-8"
                         )
@@ -242,22 +244,23 @@ object NCSuggestionsGeneratorImpl {
 
         val filteredSuggs =
             allSuggs.asScala.map {
-                case (elemId, elemSuggs) ⇒ elemId → elemSuggs.asScala.filter(_.index1 >= data.minScore)
+                case (elemId, elemSuggs) ⇒ elemId → elemSuggs.asScala.filter(_.totalScore >= data.minScore)
             }.filter(_._2.nonEmpty)
 
-        val avgScores = filteredSuggs.map { case (elemId, suggs) ⇒ elemId → (suggs.map(_.index1).sum / suggs.size) }
+        val avgScores = filteredSuggs.map { case (elemId, suggs) ⇒ elemId → (suggs.map(_.totalScore).sum / suggs.size) }
         val counts = filteredSuggs.map { case (elemId, suggs) ⇒ elemId → suggs.size }
 
         val tbl = NCAsciiTable()
 
+        // TODO: which columns do we need?
         tbl #= (
             "Element",
             "Suggestion",
-            "Summary factor",
+            "Summary Score",
             "Count",
-            "F1",
-            "F2",
-            "F3"
+            "ContextWord Score",
+            "FastText Score",
+            "Bert Score"
         )
 
         filteredSuggs.
@@ -267,7 +270,7 @@ object NCSuggestionsGeneratorImpl {
                     groupBy { case (_, stem) ⇒ stem }.
                     filter { case (stem, _) ⇒ !allSynsStems.contains(stem) }.
                     map { case (_, group) ⇒
-                        val seq = group.map { case (sugg, _) ⇒ sugg }.sortBy(-_.index1)
+                        val seq = group.map { case (sugg, _) ⇒ sugg }.sortBy(-_.totalScore)
 
                         // Drops repeated.
                         (seq.head, seq.length)
@@ -277,7 +280,7 @@ object NCSuggestionsGeneratorImpl {
                 val normFactor = seq.map(_._2).sum.toDouble / seq.size / avgScores(elemId)
 
                 seq.
-                    map { case (sugg, cnt) ⇒ (sugg, cnt, sugg.index1 * normFactor * cnt.toDouble / counts(elemId)) }.
+                    map { case (sugg, cnt) ⇒ (sugg, cnt, sugg.totalScore * normFactor * cnt.toDouble / counts(elemId)) }.
                     sortBy { case (_, _, cumFactor) ⇒ -cumFactor }.
                     zipWithIndex.
                     foreach { case ((sugg, cnt, cumFactor), sugIdx) ⇒
@@ -288,9 +291,9 @@ object NCSuggestionsGeneratorImpl {
                             sugg.word,
                             f(cumFactor),
                             cnt,
-                            f(sugg.index1),
-                            f(sugg.index2),
-                            f(sugg.index3)
+                            f(sugg.totalScore),
+                            f(sugg.fastTextScore),
+                            f(sugg.bertScore)
                         )
                     }
             }
@@ -325,7 +328,7 @@ object NCSuggestionsGeneratorImpl {
 
 object NCSuggestionsGenerator extends App {
     private lazy val DFLT_URL: String = "http://localhost:5000/suggestions"
-    private lazy val DFLT_LIMIT: Int = 10 // TODO: add scoreLimit
+    private lazy val DFLT_LIMIT: Int = 10
     private lazy val DFLT_MIN_SCORE: Double = 0
     private lazy val DFLT_DEBUG: Boolean = false
 
@@ -335,13 +338,8 @@ object NCSuggestionsGenerator extends App {
       */
     private def errorExit(msg: String = null): Unit = {
         if (msg != null)
-            System.err.println(
-                s"""
-                   |ERROR:
-                   |    $msg""".stripMargin
-            )
-
-        if (msg == null)
+            System.err.println(s"ERROR: $msg")
+        else
             System.err.println(
                 s"""
                    |NAME: