Posted to commits@nlpcraft.apache.org by if...@apache.org on 2020/10/10 02:38:31 UTC

[incubator-nlpcraft] 01/01: NLPCRAFT-93: Add sentence similarity with Bert

This is an automated email from the ASF dual-hosted git repository.

ifropc pushed a commit to branch NLPCRAFT-93
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git

commit 769f1cffcdf8ce1f193438df66fd6ac7962d59af
Author: Ifropc <if...@apache.org>
AuthorDate: Fri Oct 9 19:38:07 2020 -0700

    NLPCRAFT-93: Add sentence similarity with Bert
---
 nlpcraft/src/main/python/ctxword/bertft/bertft.py  | 24 +++++++++++++++++++---
 .../src/main/python/ctxword/bin/py_requirements    |  2 +-
 .../ctxword/jupyter/Trasnsformers-FastText.ipynb   | 17 ++++++++++++---
 3 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/nlpcraft/src/main/python/ctxword/bertft/bertft.py b/nlpcraft/src/main/python/ctxword/bertft/bertft.py
index 90b41cc..4ba405d 100644
--- a/nlpcraft/src/main/python/ctxword/bertft/bertft.py
+++ b/nlpcraft/src/main/python/ctxword/bertft/bertft.py
@@ -22,7 +22,7 @@ from pathlib import Path
 
 import fasttext.util
 import torch
-from transformers import AutoModelWithLMHead, AutoTokenizer
+from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModelForSequenceClassification
 
 from .utils import ROOT_DIR
 
@@ -80,7 +80,10 @@ class Pipeline:
         self.log.info("Loading bert")
         # ~3 GB
         self.tokenizer = AutoTokenizer.from_pretrained("roberta-large")
-        self.model = AutoModelWithLMHead.from_pretrained("roberta-large")
+        self.model = AutoModelForMaskedLM.from_pretrained("roberta-large")
+
+        self.classification_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
+        self.classification_model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
 
         if self.use_cuda:
             self.model.cuda()
@@ -113,7 +116,7 @@ class Pipeline:
                 (map(lambda x: self.replace_with_mask(x[0], x[1:]), input_data))
             )
 
-            encoded = tokenizer.batch_encode_plus(list(map(lambda x: x[1], sentences)), pad_to_max_length=True)
+            encoded = tokenizer.batch_encode_plus(list(map(lambda x: x[1], sentences)), padding='longest')
             input_ids = torch.tensor(encoded['input_ids'], device=self.device)
             attention_mask = torch.tensor(encoded['attention_mask'], device=self.device)
 
@@ -207,6 +210,21 @@ class Pipeline:
 
         return result
 
+    def sentence_similarity(self, s1, s2):
+        with torch.no_grad():
+            tokenizer = self.classification_tokenizer
+            model = self.classification_model
+
+            classes = ["not paraphrase", "is paraphrase"]
+
+            tokens = tokenizer(s1, s2, return_tensors="pt")
+
+            classification_logits = model(**tokens)[0]
+
+            results = torch.softmax(classification_logits, dim=1).tolist()[0]
+
+            print(results[1])
+
     def print_time(self, start, message):
         current = time.time()
         self.log.info(message + " in %s ms", '{0:.4f}'.format((current - start) * 1000))
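
The new sentence_similarity method scores a sentence pair with
"bert-base-cased-finetuned-mrpc", a BERT checkpoint fine-tuned for
paraphrase detection on the MRPC corpus. Note that, as committed, the
method prints the softmax probability of the "is paraphrase" class
(results[1]) rather than returning it. A minimal standalone sketch of
the same scoring logic that returns the probability instead (the model
name and API calls are taken from the diff above; everything else is
illustrative):

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # Same MRPC-fine-tuned checkpoint the commit loads in Pipeline.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")

    def sentence_similarity(s1, s2):
        # Inference only, so gradients are not needed.
        with torch.no_grad():
            tokens = tokenizer(s1, s2, return_tensors="pt")
            logits = model(**tokens)[0]
            # Index 0 = "not paraphrase", index 1 = "is paraphrase".
            return torch.softmax(logits, dim=1).tolist()[0][1]

    print(sentence_similarity("Current weather in London", "Would it rain in London?"))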
diff --git a/nlpcraft/src/main/python/ctxword/bin/py_requirements b/nlpcraft/src/main/python/ctxword/bin/py_requirements
index bb0c20d..c234f76 100644
--- a/nlpcraft/src/main/python/ctxword/bin/py_requirements
+++ b/nlpcraft/src/main/python/ctxword/bin/py_requirements
@@ -17,5 +17,5 @@
 
 # Dependency list for 'ctxword' Python module.
 flask==1.1.2
-transformers==2.7.0
+transformers==3.3.1
 torch==1.6.0
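
The bump from transformers 2.7.0 to 3.3.1 is what drives the API
changes in bertft.py above: AutoModelWithLMHead was deprecated in the
3.x line in favor of task-specific classes such as AutoModelForMaskedLM,
and the pad_to_max_length=True flag to batch_encode_plus was superseded
by the padding argument. The two calls below should be equivalent for
batch encoding (texts stands in for any list of input strings):

    # transformers 2.x style, deprecated in 3.x:
    encoded = tokenizer.batch_encode_plus(texts, pad_to_max_length=True)

    # transformers 3.x form used by this commit; pads every sequence
    # to the longest one in the batch.
    encoded = tokenizer.batch_encode_plus(texts, padding='longest')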
diff --git a/nlpcraft/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb b/nlpcraft/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
index 78b91f1..3cfda57 100644
--- a/nlpcraft/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
+++ b/nlpcraft/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
@@ -52,7 +52,6 @@
    "outputs": [],
    "source": [
     "import bertft\n",
-    "from bertft import lget\n",
     "import matplotlib.pyplot as plt\n",
     "import pandas as pd"
    ]
@@ -143,7 +142,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "pipeline = bertft.Pipeline()"
+    "pipeline = bertft.Pipeline(False)"
    ]
   },
   {
@@ -179,6 +178,18 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
+   "source": [
+    "pipeline.sentence_similarity(\"Current weather in New York\", \"Would it rain in London?\")\n",
+    "pipeline.sentence_similarity(\"Current weather in London\", \"Would it rain in London?\")\n",
+    "pipeline.sentence_similarity(\"Current weather in New York\", \"Buy car in London\")\n",
+    "pipeline.sentence_similarity(\"Current weather in London\", \"Buy car in London\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": []
   }
  ],
@@ -198,7 +209,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.3"
+   "version": "3.8.5"
   }
  },
  "nbformat": 4,