You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nlpcraft.apache.org by if...@apache.org on 2020/10/10 02:38:31 UTC
[incubator-nlpcraft] 01/01: NLPCRAFT-93: Add sentence similarity with Bert
This is an automated email from the ASF dual-hosted git repository.
ifropc pushed a commit to branch NLPCRAFT-93
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git
commit 769f1cffcdf8ce1f193438df66fd6ac7962d59af
Author: Ifropc <if...@apache.org>
AuthorDate: Fri Oct 9 19:38:07 2020 -0700
NLPCRAFT-93: Add sentence similarity with Bert
---
nlpcraft/src/main/python/ctxword/bertft/bertft.py | 24 +++++++++++++++++++---
.../src/main/python/ctxword/bin/py_requirements | 2 +-
.../ctxword/jupyter/Trasnsformers-FastText.ipynb | 17 ++++++++++++---
3 files changed, 36 insertions(+), 7 deletions(-)
diff --git a/nlpcraft/src/main/python/ctxword/bertft/bertft.py b/nlpcraft/src/main/python/ctxword/bertft/bertft.py
index 90b41cc..4ba405d 100644
--- a/nlpcraft/src/main/python/ctxword/bertft/bertft.py
+++ b/nlpcraft/src/main/python/ctxword/bertft/bertft.py
@@ -22,7 +22,7 @@ from pathlib import Path
import fasttext.util
import torch
-from transformers import AutoModelWithLMHead, AutoTokenizer
+from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModelForSequenceClassification
from .utils import ROOT_DIR
@@ -80,7 +80,10 @@ class Pipeline:
self.log.info("Loading bert")
# ~3 GB
self.tokenizer = AutoTokenizer.from_pretrained("roberta-large")
- self.model = AutoModelWithLMHead.from_pretrained("roberta-large")
+ self.model = AutoModelForMaskedLM.from_pretrained("roberta-large")
+
+ self.classification_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
+ self.classification_model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
if self.use_cuda:
self.model.cuda()
@@ -113,7 +116,7 @@ class Pipeline:
(map(lambda x: self.replace_with_mask(x[0], x[1:]), input_data))
)
- encoded = tokenizer.batch_encode_plus(list(map(lambda x: x[1], sentences)), pad_to_max_length=True)
+ encoded = tokenizer.batch_encode_plus(list(map(lambda x: x[1], sentences)), padding='longest')
input_ids = torch.tensor(encoded['input_ids'], device=self.device)
attention_mask = torch.tensor(encoded['attention_mask'], device=self.device)
@@ -207,6 +210,21 @@ class Pipeline:
return result
+    def sentence_similarity(self, s1, s2):
+        """Return the probability (0..1) that s1 and s2 are paraphrases, per self.classification_model."""
+        tokenizer = self.classification_tokenizer
+        model = self.classification_model
+        # Encode the two sentences as one sentence pair (a batch of size one).
+        tokens = tokenizer(s1, s2, return_tensors="pt")
+        # Put the inputs on the model's own device so this works whether or not the model was moved to CUDA.
+        device = next(model.parameters()).device
+        tokens = {name: tensor.to(device) for name, tensor in tokens.items()}
+
+        with torch.no_grad():
+            classification_logits = model(**tokens)[0]
+
+        # Class order: index 0 = "not paraphrase", index 1 = "is paraphrase"; return (not print) the latter.
+        return torch.softmax(classification_logits, dim=1)[0, 1].item()
def print_time(self, start, message):
current = time.time()
self.log.info(message + " in %s ms", '{0:.4f}'.format((current - start) * 1000))
diff --git a/nlpcraft/src/main/python/ctxword/bin/py_requirements b/nlpcraft/src/main/python/ctxword/bin/py_requirements
index bb0c20d..c234f76 100644
--- a/nlpcraft/src/main/python/ctxword/bin/py_requirements
+++ b/nlpcraft/src/main/python/ctxword/bin/py_requirements
@@ -17,5 +17,5 @@
# Dependency list for 'ctxword' Python module.
flask==1.1.2
-transformers==2.7.0
+transformers==3.3.1
torch==1.6.0
diff --git a/nlpcraft/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb b/nlpcraft/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
index 78b91f1..3cfda57 100644
--- a/nlpcraft/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
+++ b/nlpcraft/src/main/python/ctxword/jupyter/Trasnsformers-FastText.ipynb
@@ -52,7 +52,6 @@
"outputs": [],
"source": [
"import bertft\n",
- "from bertft import lget\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd"
]
@@ -143,7 +142,7 @@
"metadata": {},
"outputs": [],
"source": [
- "pipeline = bertft.Pipeline()"
+ "pipeline = bertft.Pipeline(False)"
]
},
{
@@ -179,6 +178,18 @@
"execution_count": null,
"metadata": {},
"outputs": [],
+ "source": [
+ "pipeline.sentence_similarity(\"Current weather in New York\", \"Would it rain in London?\")\n",
+ "pipeline.sentence_similarity(\"Current weather in London\", \"Would it rain in London?\")\n",
+ "pipeline.sentence_similarity(\"Current weather in New York\", \"Buy car in London\")\n",
+ "pipeline.sentence_similarity(\"Current weather in London\", \"Buy car in London\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": []
}
],
@@ -198,7 +209,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.3"
+ "version": "3.8.5"
}
},
"nbformat": 4,