You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original archive page and is not preserved in this plain-text copy.
Posted to commits@nlpcraft.apache.org by if...@apache.org on 2020/06/07 21:22:24 UTC

[incubator-nlpcraft] 01/01: NLPCRAFT-67: Add batching

This is an automated email from the ASF dual-hosted git repository.

ifropc pushed a commit to branch NLPCRAFT-67
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git

commit 94f518b607f48884b8f6ef2645649fc11ecb6192
Author: Ifropc <if...@apache.org>
AuthorDate: Sun Jun 7 14:21:02 2020 -0700

    NLPCRAFT-67: Add batching
---
 src/main/python/ctxword/bertft/bertft.py           | 119 ++++++++++-----------
 src/main/python/ctxword/bin/predict.sh             |   2 +-
 .../ctxword/jupyter/Trasnsformers-FastText.ipynb   |  25 +++--
 src/main/python/ctxword/server.py                  |  11 +-
 4 files changed, 80 insertions(+), 77 deletions(-)

diff --git a/src/main/python/ctxword/bertft/bertft.py b/src/main/python/ctxword/bertft/bertft.py
index 6d27d7a..16c20c5 100644
--- a/src/main/python/ctxword/bertft/bertft.py
+++ b/src/main/python/ctxword/bertft/bertft.py
@@ -33,10 +33,6 @@ def lget(lst, pos):
     return list(map(lambda x: x[pos], lst))
 
 
-def calc_w(x, y, w):
-    return x * w[0] + y * w[1]
-
-
 # TODO: make Model configurable
 # TODO: add type check
 class Pipeline:
@@ -77,11 +73,9 @@ class Pipeline:
         self.tokenizer = AutoTokenizer.from_pretrained("roberta-large")
         self.model = AutoModelWithLMHead.from_pretrained("roberta-large")
 
-        self.on_run = on_run
-
         self.log.info("Server started in %s seconds", ('{0:.4f}'.format(time.time() - start_time)))
 
-    def find_top(self, sentence, index, k, top_bert, bert_norm, min_ftext, weights, min_score):
+    def find_top(self, input_data, k, top_bert, min_ftext, weights, min_score):
         tokenizer = self.tokenizer
         model = self.model
         ft = self.ft
@@ -89,88 +83,91 @@ class Pipeline:
         k = 10 if k is None else k
         min_score = 0 if min_score is None else min_score
 
-        self.log.debug("Input: %s", sentence)
         start_time = time.time()
+        req_start_time = start_time
 
-        lst = sentence.split()
-
-        target = lst[index]
+        sentences = list(map(lambda x: self.replace_with_mask(x[0], x[1]), input_data))
 
-        seqlst = lst[:index]
-        seqlst.append(tokenizer.mask_token)
-        seqlst.extend(lst[(index + 1):])
-        sequence = " ".join(seqlst)
+        encoded = tokenizer.batch_encode_plus(list(map(lambda x: x[1], sentences)), pad_to_max_length=True)
+        input_ids = torch.tensor(encoded['input_ids'])
+        attention_mask = torch.tensor(encoded['attention_mask'])
 
-        self.log.debug("Target word: %s; sequence: %s", target, sequence)
+        start_time = self.print_time(start_time, "Tokenizing finished")
+        forward = model(input_ids=input_ids, attention_mask=attention_mask)
 
-        input = tokenizer.encode(sequence, return_tensors="pt")
-        mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]
+        start_time = self.print_time(start_time, "Batch finished (Bert)")
 
-        token_logits = model(input)[0]
+        mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]
+        token_logits = forward[0]
         mask_token_logits = token_logits[0, mask_token_index, :]
 
         # Filter top <top_bert> results of bert output
         topk = torch.topk(mask_token_logits, top_bert, dim=1)
-        top_tokens = list(zip(topk.indices[0].tolist(), topk.values[0].tolist()))
 
-        unfiltered = list()
-        filtered = list()
+        nvl = []
 
-        norm_d = top_tokens[bert_norm - 1][1]
-        norm_k = top_tokens[0][1] - norm_d
+        for d in topk.values:
+            nmin = torch.min(d)
+            nmax = torch.max(d)
+            nvl.append((d - nmin) / (nmax - nmin))
 
-        self.log.info("Bert finished in %s seconds", '{0:.4f}'.format(time.time() - start_time))
+        start_time = self.print_time(start_time, "Bert post-processing")
 
-        # Filter bert output by <min_ftext>
-        # TODO: calculate batch similarity
-        for token, value in top_tokens:
-            word = tokenizer.decode([token]).strip()
-            norm_value = (value - norm_d) / norm_k
+        suggestions = []
+        for index in topk.indices:
+            lst = list(index)
+            tmp = []
+            for single in lst:
+                tmp.append(tokenizer.decode([single]).strip())
+            suggestions.append(tuple(tmp))
 
-            sim = cosine_similarity(ft[target].reshape(1, -1), ft[word].reshape(1, -1))[0][0]
+        self.print_time(start_time, "Bert decoding")
 
-            sentence_sim = cosine_similarity(
-                ft.get_sentence_vector(sentence).reshape(1, -1),
-                ft.get_sentence_vector(re.sub(tokenizer.mask_token, word, sequence)).reshape(1, -1)
-            )[0][0]
+        cos = torch.nn.CosineSimilarity()
 
-            # Continue only for jupyter
-            if self.on_run is None and word == target:
-                continue
+        result = []
 
-            score = calc_w(norm_value, sim, weights)
+        for i in range(0, len(sentences)):
+            target = sentences[i][0]
+            suggest_embeddings = torch.tensor(list(map(lambda x: ft[x], suggestions[i])))
+            targ_tenzsor = torch.tensor(ft[target]).expand(suggest_embeddings.shape)
+            similarities = cos(targ_tenzsor, suggest_embeddings)
 
-            if sim >= min_ftext and score > min_score:
-                filtered.append((word, value, norm_value, sim, sentence_sim, score))
+            scores = nvl[i] * weights[0] + similarities * weights[1]
 
-            unfiltered.append((word, value, norm_value, sim, sentence_sim, score))
+            result.append(
+                sorted(
+                    filter(
+                        lambda x: x[0] > min_score and x[1] > min_ftext,
+                        zip(scores.tolist(), similarities.tolist(), suggestions[i], nvl[i].tolist())
+                    ),
+                    key=lambda x: x[0],
+                    reverse=True
+                )[:k]
+            )
 
-        done = (time.time() - start_time)
+        self.print_time(req_start_time, "Request processed")
 
-        kfiltered = filtered[:k]
-        kunfiltered = unfiltered[:k]
+        return result
 
-        kfiltered = sorted(kfiltered, key=lambda x: -x[len(x) - 1])
-        kunfiltered = sorted(kunfiltered, key=lambda x: -x[len(x) - 1])
+    def replace_with_mask(self, sentence, index):
+        lst = sentence.split()
 
-        filtered_top = pd.DataFrame({
-            'word': lget(kfiltered, 0),
-            'bert': self.dget(kfiltered, 1),
-            'normalized': self.dget(kfiltered, 2),
-            'ftext': self.dget(kfiltered, 3),
-            'ftext-sentence': self.dget(kfiltered, 4),
-            'score': lget(kfiltered, 5),
-        })
+        target = lst[index]
 
-        if self.on_run != None:
-            self.on_run(self, kunfiltered, unfiltered, filtered_top, target, tokenizer, top_tokens)
+        seqlst = lst[:index]
+        seqlst.append(self.tokenizer.mask_token)
+        seqlst.extend(lst[(index + 1):])
 
-        self.log.info("Processing finished in %s seconds", '{0:.4f}'.format(done))
+        return (target, " ".join(seqlst))
 
-        return filtered_top
+    def print_time(self, start, message):
+        current = time.time()
+        self.log.info(message + " in %s ms", '{0:.4f}'.format((current - start) * 1000))
+        return current
 
-    def do_find(self, s, index, limit, min_score):
-        return self.find_top(s, index, limit, 200, 200, 0.25, [1, 1], min_score)
+    def do_find(self, data, limit, min_score):
+        return self.find_top(data, limit, 100, 0.25, [1, 1], min_score)
 
     def dget(self, lst, pos):
         return list(map(lambda x: '{0:.2f}'.format(x[pos]), lst)) if self.on_run is not None else lget(lst, pos)
diff --git a/src/main/python/ctxword/bin/predict.sh b/src/main/python/ctxword/bin/predict.sh
index 0fd5f22..568c17a 100755
--- a/src/main/python/ctxword/bin/predict.sh
+++ b/src/main/python/ctxword/bin/predict.sh
@@ -16,4 +16,4 @@
 # limitations under the License.
 #
 
-curl -d "{\"sentence\": \"$1\",\"simple\": true, \"index\": $2, \"limit\": 10}" -H 'Content-Type: application/json' http://localhost:5000/suggestions
+curl -d "{\"sentences\": [[\"$1\", $2]], \"simple\": true, \"limit\": 10}" -H 'Content-Type: application/json' http://localhost:5000/suggestions
diff --git a/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb b/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
index 8330cb0..59a000c 100644
--- a/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
+++ b/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
@@ -143,7 +143,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "pipeline = bertft.Pipeline(on_run)"
+    "pipeline = bertft.Pipeline()"
    ]
   },
   {
@@ -153,18 +153,25 @@
    "outputs": [],
    "source": [
     "# Example of usage\n",
-    "x = pipeline.find_top(\n",
-    "    \"what is the local weather forecast?\", # mark target word with #\n",
-    "    [4, 4], # or pass words position range (inclusive) in the sentece\n",
+    "res = pipeline.find_top(\n",
+    "    # List of sentences with target word position\n",
+    "    [\n",
+    "        (\"what is the local weather forecast?\", 4),\n",
+    "        (\"what is chances of rain tomorrow?\", 4),\n",
+    "        (\"what is chances of rain tomorrow?\", 2),\n",
+    "        (\"is driving a car faster then taking a bus?\", 3),\n",
+    "        (\"who is the best football player of all time?\", 4)\n",
+    "    ],\n",
     "    k = 20, # Filter best k results (by weighted score)\n",
-    "    top_bert = 200, # Number of initial filter of bert output \n",
-    "    bert_norm = 200, # Use this position for normalization of bert output \n",
+    "    top_bert = 100, # Number of initial filter of bert output \n",
     "    min_ftext = 0.3, # Minimal required score of fast text  \n",
     "    weights = [ # Weights of models scores to calculate total weighted score\n",
     "        1, # bert\n",
     "        1, # fast text\n",
-    "    ]\n",
-    ")"
+    "    ],\n",
+    "    min_score = 0 # Minimum required score\n",
+    ")\n",
+    "print(res)"
    ]
   },
   {
@@ -191,7 +198,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.2"
+   "version": "3.8.3"
   }
  },
  "nbformat": 4,
diff --git a/src/main/python/ctxword/server.py b/src/main/python/ctxword/server.py
index a5899a0..9846b56 100644
--- a/src/main/python/ctxword/server.py
+++ b/src/main/python/ctxword/server.py
@@ -17,6 +17,7 @@
 import logging
 from flask import Flask
 from flask import request
+from flask import jsonify
 from bertft import Pipeline
 
 logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.DEBUG)
@@ -55,14 +56,12 @@ def main():
 
     json = request.json
 
-    sentence = present(json, 'sentence')
-    index = present(json, 'index')
+    sentences = present(json, 'sentences')
     limit = json['limit'] if 'limit' in json else None
     min_score = json['min_score'] if 'min_score' in json else None
 
-    data = pipeline.do_find(sentence, index, limit, min_score)
+    data = pipeline.do_find(sentences, limit, min_score)
     if 'simple' not in json or not json['simple']:
-        json_data = data.to_json(orient='table', index=False)
+        return jsonify(data)
     else:
-        json_data = data['word'].to_json(orient='values')
-    return app.response_class(response=json_data, status=200, mimetype='application/json')
+        return jsonify(list(map(lambda x: list(map(lambda y: y[2], x)), data)))