Posted to commits@nlpcraft.apache.org by if...@apache.org on 2020/06/05 04:23:12 UTC

[incubator-nlpcraft] 02/02: NLPCRAFT-11: After-review changes

This is an automated email from the ASF dual-hosted git repository.

ifropc pushed a commit to branch NLPCRAFT-67
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git

commit b25635e756517b6eeb8c48c3562d0517b4cfb24f
Author: Ifropc <if...@protonmail.com>
AuthorDate: Sat May 16 20:01:19 2020 -0700

    NLPCRAFT-11: After-review changes
    
    - Added validation, limit argument
---
 enricher/README.md                   |  8 +++---
 enricher/bertft/bertft.py            | 48 +++++++++++++++---------------------
 enricher/bin/install_dependencies.sh |  2 +-
 enricher/bin/predict.sh              |  2 +-
 enricher/bin/py_requirements         |  2 +-
 enricher/bin/start_server.sh         |  2 +-
 enricher/server.py                   | 42 ++++++++++++++++++++++++-------
 7 files changed, 62 insertions(+), 44 deletions(-)

diff --git a/enricher/README.md b/enricher/README.md
index a61eac2..977482e 100644
--- a/enricher/README.md
+++ b/enricher/README.md
@@ -21,10 +21,12 @@ To start server:
 `$ bin/start_server.sh`  
 
 Server has single route in root which accepts POST json requests with parameters: 
-* "sentence": Target sentence. Word to find synonyms for must be marked with `#`
-* "lower", "upper" (Optional, substitute marking with `#`): Positions in the sentence of start and end of collocation to find synonyms for.  
+* "sentence": Target sentence. Number of word to find synonyms for must be passed as argument
+* "lower", "upper": Positions in the sentence of start and end of collocation to find synonyms for.  
 Note: sentence is split via whitespaces, upper bound is inclusive. 
 * "simple" (Optional): If set to true omits verbose data.  
+* "limit": Sets limit of result words number.  
 
 Simple request could be made with a script, e.g.  
-`$ bin/predict.sh "what is the chance of rain# tomorrow?"`
\ No newline at end of file
+`$ bin/predict.sh "what is the chance of rain tomorrow?" 5`  
+This would find synonyms for the word "rain" in this sentence.
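
For illustration, here is a minimal client sketch of the request shape documented above, using only the Python standard library. The /synonyms route, port 5000 and parameter names are taken from this patch; the concrete word index is just an example.

    # Minimal sketch of a POST to /synonyms matching the documented parameters.
    # Assumes the Flask server from start_server.sh is listening on localhost:5000.
    import json
    import urllib.request

    payload = {
        "sentence": "what is the chance of rain tomorrow?",
        "lower": 5,      # index of "rain" (the sentence is split on whitespace)
        "upper": 5,      # upper bound is inclusive
        "limit": 10,     # maximum number of result words
        "simple": True,  # omit verbose data
    }

    req = urllib.request.Request(
        "http://localhost:5000/synonyms",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )

    with urllib.request.urlopen(req) as resp:
        print(resp.read().decode("utf-8"))
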
diff --git a/enricher/bertft/bertft.py b/enricher/bertft/bertft.py
index 05c32a7..d315b49 100644
--- a/enricher/bertft/bertft.py
+++ b/enricher/bertft/bertft.py
@@ -81,7 +81,6 @@ class Pipeline:
 
         self.log.info("Server started in %s seconds", ('{0:.4f}'.format(time.time() - start_time)))
 
-    # TODO(?): remove split by #
     def find_top(self, sentence, positions, k, top_bert, bert_norm, min_ftext, weights):
         tokenizer = self.tokenizer
         model = self.model
@@ -89,36 +88,29 @@ class Pipeline:
 
         self.log.debug("Input: %s", sentence)
         start_time = time.time()
-        sentence_match = re.search("(\w+)#(\w+)?", sentence)
-        target = None
 
-        if sentence_match:
-            target = re.sub("#", "", sentence_match.group(1))
-            target = target.strip()
-            sequence = re.sub("(\w+)?#(\w+)?", tokenizer.mask_token, sentence)
+        lst = sentence.split()
+        lower = positions[0]
+        upper = positions[1] + 1
+        target = "-".join(lst[lower:upper])
+        if lower == positions[1] or target in self.ft_dict:
+            seqlst = lst[:lower]
+            seqlst.append(tokenizer.mask_token)
+            seqlst.extend(lst[upper:])
+            sequence = " ".join(seqlst)
         else:
-            lst = sentence.split()
-            lower = positions[0]
-            upper = positions[1] + 1
-            target = "-".join(lst[lower:upper])
-            if lower == positions[1] or target in self.ft_dict:
+            rec = list()
+
+            for i in range(lower, upper):
                 seqlst = lst[:lower]
-                seqlst.append(tokenizer.mask_token)
+                seqlst.append(lst[i])
                 seqlst.extend(lst[upper:])
-                sequence = " ".join(seqlst)
-            else:
-                rec = list()
-
-                for i in range(lower, upper):
-                    seqlst = lst[:lower]
-                    seqlst.append(lst[i])
-                    seqlst.extend(lst[upper:])
-                    rec.append(
-                        self.find_top(" ".join(seqlst), [lower, lower], k, top_bert, bert_norm, min_ftext, weights))
+                rec.append(
+                    self.find_top(" ".join(seqlst), [lower, lower], k, top_bert, bert_norm, min_ftext, weights))
 
-                rec = sorted(rec, key=lambda x: x.score.mean(), reverse=True)
+            rec = sorted(rec, key=lambda x: x.score.mean(), reverse=True)
 
-                return rec[0]
+            return rec[0]
 
         self.log.debug("Target word: %s; sequence: %s", target, sequence)
 
@@ -149,7 +141,7 @@ class Pipeline:
             sim = cosine_similarity(ft[target].reshape(1, -1), ft[word].reshape(1, -1))[0][0]
 
             sentence_sim = cosine_similarity(
-                ft.get_sentence_vector(re.sub("#", "", sentence)).reshape(1, -1),
+                ft.get_sentence_vector(sentence).reshape(1, -1),
                 ft.get_sentence_vector(re.sub(tokenizer.mask_token, word, sequence)).reshape(1, -1)
             )[0][0]
 
@@ -186,8 +178,8 @@ class Pipeline:
 
         return filtered_top
 
-    def do_find(self, s, positions):
-        return self.find_top(s, positions, 10, 200, 200, 0.25, [1, 1])
+    def do_find(self, s, positions, limit):
+        return self.find_top(s, positions, limit, 200, 200, 0.25, [1, 1])
 
     def dget(self, lst, pos):
         return list(map(lambda x: '{0:.2f}'.format(x[pos]), lst)) if self.on_run is not None else lget(lst, pos)
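
For context on the reworked find_top above, here is a standalone sketch of the new masking step; "[MASK]" stands in for tokenizer.mask_token, and the positions follow the same convention as the patch (upper bound inclusive).

    # Standalone sketch of the masking logic added to find_top: the target
    # collocation at [lower, upper] is joined with '-' and, when it is a single
    # word or present in the fastText dictionary, replaced by the mask token in
    # the sequence handed to BERT.
    MASK = "[MASK]"  # stand-in for tokenizer.mask_token

    sentence = "what is the chance of rain tomorrow?"
    lst = sentence.split()
    lower, upper = 5, 5 + 1               # positions[0], positions[1] + 1

    target = "-".join(lst[lower:upper])   # "rain"

    seqlst = lst[:lower]
    seqlst.append(MASK)
    seqlst.extend(lst[upper:])
    sequence = " ".join(seqlst)           # "what is the chance of [MASK] tomorrow?"

    print(target, "->", sequence)
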
diff --git a/enricher/bin/install_dependencies.sh b/enricher/bin/install_dependencies.sh
index 4f88608..964bfbc 100755
--- a/enricher/bin/install_dependencies.sh
+++ b/enricher/bin/install_dependencies.sh
@@ -30,7 +30,7 @@ abort() {
   [ ! -f data/cc.en.300.bin ] && \
   { wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz -P data || \
   abort "Failed to download fast text data"; }
-[ ! -f data/cc.en.300.bin ] && { gunzip data/cc.en.300.bin.gz || abort "Failed to extract files"; }
+[ ! -f data/cc.en.300.bin ] && { gunzip -v data/cc.en.300.bin.gz || abort "Failed to extract files"; }
 
 [ ! -d /tmp/fastText/ ] && git clone https://github.com/facebookresearch/fastText.git /tmp/fastText
 pip3 install /tmp/fastText || abort "Failed to install fast text python module"
diff --git a/enricher/bin/predict.sh b/enricher/bin/predict.sh
index b2b4388..ef9d551 100755
--- a/enricher/bin/predict.sh
+++ b/enricher/bin/predict.sh
@@ -16,4 +16,4 @@
 # limitations under the License.
 #
 
-curl -d "{\"sentence\": \"$1\",\"simple\": true}" -H 'Content-Type: application/json' http://localhost:5000
+curl -d "{\"sentence\": \"$1\",\"simple\": true, \"lower\": $2, \"upper\": $2, \"limit\": 10}" -H 'Content-Type: application/json' http://localhost:5000/synonyms
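
With this change, for example, bin/predict.sh "what is the chance of rain tomorrow?" 5 posts {"sentence": "what is the chance of rain tomorrow?", "simple": true, "lower": 5, "upper": 5, "limit": 10} to http://localhost:5000/synonyms: the second script argument is used as both the lower and upper bound of a single-word collocation, and the result limit is fixed at 10.
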
diff --git a/enricher/bin/py_requirements b/enricher/bin/py_requirements
index 23b3646..af3a51e 100644
--- a/enricher/bin/py_requirements
+++ b/enricher/bin/py_requirements
@@ -17,7 +17,7 @@
 
 Flask==1.1.2
 transformers==2.7.0
-torch==1.4.0
+torch==1.5.0
 pandas==1.0.3
 scikit-learn==0.22.2.post1
 
diff --git a/enricher/bin/start_server.sh b/enricher/bin/start_server.sh
index 8e382e4..ec4e816 100755
--- a/enricher/bin/start_server.sh
+++ b/enricher/bin/start_server.sh
@@ -16,4 +16,4 @@
 # limitations under the License.
 #
 
-FLASK_APP=server.py python -m flask run
+FLASK_APP=server.py python3 -m flask run
diff --git a/enricher/server.py b/enricher/server.py
index bd21aea..bad42cb 100644
--- a/enricher/server.py
+++ b/enricher/server.py
@@ -17,8 +17,6 @@
 import logging
 from flask import Flask
 from flask import request
-from flask import abort
-from flask import Response
 from bertft import Pipeline
 
 logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.DEBUG)
@@ -28,17 +26,43 @@ app = Flask(__name__)
 pipeline = Pipeline()
 
 
-@app.route('/', methods=['POST'])
+class ValidationException(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+
+
+@app.errorhandler(ValidationException)
+def handle_bad_request(e):
+    return str(e), 400
+
+
+def check_condition(condition, supplier, message):
+    if condition:
+        return supplier()
+    else:
+        raise ValidationException(message)
+
+
+def present(json, name):
+    return check_condition(name in json, lambda: json[name],
+                           "Required '" + name + "' argument is not present")
+
+
+@app.route('/synonyms', methods=['POST'])
 def main():
     if not request.is_json:
-        abort(Response("Json expected"))
+        raise ValidationException("Json expected")
 
     json = request.json
-    sentence = json['sentence']
-    upper = None if 'upper' not in json else json['upper']
-    lower = None if 'lower' not in json else json['lower']
-    positions = None if upper is None or lower is None else [lower, upper]
-    data = pipeline.do_find(sentence, positions)
+
+    sentence = present(json, 'sentence')
+    upper = present(json, 'upper')
+    lower = present(json, 'lower')
+    positions = check_condition(lower <= upper, lambda: [lower, upper],
+                                "Lower bound must be less than or equal to the upper bound")
+    limit = present(json, 'limit')
+
+    data = pipeline.do_find(sentence, positions, limit)
     if 'simple' not in json or not json['simple']:
         json_data = data.to_json(orient='table', index=False)
     else:
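
For reference, a minimal sketch of how the new validation behaves, using Flask's test client. It assumes the patched server.py above is importable as server (note that importing it instantiates Pipeline and therefore loads the models).

    # Minimal sketch: exercising the new validation with Flask's test client.
    from server import app

    client = app.test_client()

    # Omitting the required 'limit' argument is caught by present(), raised as
    # a ValidationException and turned into an HTTP 400 by the error handler.
    resp = client.post("/synonyms", json={
        "sentence": "what is the chance of rain tomorrow?",
        "lower": 5,
        "upper": 5,
    })
    print(resp.status_code)             # 400
    print(resp.get_data(as_text=True))  # Required 'limit' argument is not present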