You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nlpcraft.apache.org by if...@apache.org on 2020/06/07 22:42:43 UTC
[incubator-nlpcraft] 01/06: NLPCRAFT-67: Add auto-enrich using Bert
and FastText models
This is an automated email from the ASF dual-hosted git repository.
ifropc pushed a commit to branch NLPCRAFT-67
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git
commit 359ef9f98e05eebbebbb83a3d84a9d7cdfbcaec9
Author: Ifropc <if...@apache.org>
AuthorDate: Tue May 12 09:07:33 2020 -0700
NLPCRAFT-67: Add auto-enrich using Bert and FastText models
Add apache license
Replace prints with logging
Move printing graphs to jupyter
---
.gitignore | 3 +
enricher/README.md | 30 ++++
enricher/bertft/__init__.py | 18 +++
enricher/bertft/bertft.py | 193 +++++++++++++++++++++++++
enricher/bertft/utils.py | 19 +++
enricher/bin/install_dependencies.sh | 39 +++++
enricher/bin/predict.sh | 19 +++
enricher/bin/py_requirements | 25 ++++
enricher/bin/start_server.sh | 19 +++
enricher/jupyter/Trasnsformers-FastText.ipynb | 199 ++++++++++++++++++++++++++
enricher/server.py | 46 ++++++
11 files changed, 610 insertions(+)
diff --git a/.gitignore b/.gitignore
index 7948152..1291eca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,6 @@ out
zips
model.yaml
model.json
+**/__pycache__
+**/enricher/data
+**/.ipynb_checkpoints/
diff --git a/enricher/README.md b/enricher/README.md
new file mode 100644
index 0000000..a61eac2
--- /dev/null
+++ b/enricher/README.md
@@ -0,0 +1,30 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+To install dependencies:
+`$ bin/install_dependencies.sh`
+To start server:
+`$ bin/start_server.sh`
+
+The server has a single route at the root, which accepts POST JSON requests with the following parameters:
+* "sentence": Target sentence. The word to find synonyms for must be marked with `#`.
+* "lower", "upper" (Optional, an alternative to marking with `#`): Positions in the sentence of the first and last words of the collocation to find synonyms for.
+Note: the sentence is split on whitespace, positions are zero-based word indices, and the upper bound is inclusive.
+* "simple" (Optional): If set to true, omits verbose data and returns only the list of suggested words.
+
+A simple request can be made with a script, e.g.
+`$ bin/predict.sh "what is the chance of rain# tomorrow?"`
\ No newline at end of file
diff --git a/enricher/bertft/__init__.py b/enricher/bertft/__init__.py
new file mode 100644
index 0000000..933bb48
--- /dev/null
+++ b/enricher/bertft/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .bertft import Pipeline
+from .bertft import lget
diff --git a/enricher/bertft/bertft.py b/enricher/bertft/bertft.py
new file mode 100644
index 0000000..05c32a7
--- /dev/null
+++ b/enricher/bertft/bertft.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from transformers import AutoModelWithLMHead, AutoTokenizer
+import logging
+import torch
+import re
+import pandas as pd
+from sklearn.metrics.pairwise import cosine_similarity
+import time
+from pathlib import Path
+import fasttext.util
+from .utils import ROOT_DIR
+
+
+def lget(lst, pos):
+ return list(map(lambda x: x[pos], lst))
+
+
+def calc_w(x, y, w):
+ return x * w[0] + y * w[1]
+
+
+# TODO: make Model configurable
+# TODO: add type check
+class Pipeline:
+ def __init__(self, on_run=None):
+ self.log = logging.getLogger("bertft")
+
+ start_time = time.time()
+ # ft_size = 100 # ~2.6 GB
+ ft_size = 200 # ~4.5 GB
+ # ft_size = 300 # ~8 GB
+
+ self.ft_size = ft_size
+
+ def get_ft_path(n):
+ return ROOT_DIR + "/data/cc.en." + str(n) + ".bin"
+
+ cur_path = get_ft_path(ft_size)
+
+ self.log.warning("Initializing fast text")
+
+ if Path(cur_path).exists():
+ self.log.info("Found existing model, loading.")
+ ft = fasttext.load_model(cur_path)
+ else:
+ self.log.info("Configured model is not found. Loading default model.")
+ ft = fasttext.load_model(get_ft_path(300))
+
+ self.log.info("Compressing model")
+ fasttext.util.reduce_model(ft, ft_size)
+
+ ft.save_model(cur_path)
+
+ self.ft = ft
+ self.ft_dict = set(ft.get_words())
+
+ self.log.info("Loading bert")
+ # ~3 GB
+ self.tokenizer = AutoTokenizer.from_pretrained("roberta-large")
+ self.model = AutoModelWithLMHead.from_pretrained("roberta-large")
+
+ self.on_run = on_run
+
+ self.log.info("Server started in %s seconds", ('{0:.4f}'.format(time.time() - start_time)))
+
+ # TODO(?): remove split by #
+ def find_top(self, sentence, positions, k, top_bert, bert_norm, min_ftext, weights):
+ tokenizer = self.tokenizer
+ model = self.model
+ ft = self.ft
+
+ self.log.debug("Input: %s", sentence)
+ start_time = time.time()
+ sentence_match = re.search("(\w+)#(\w+)?", sentence)
+ target = None
+
+ if sentence_match:
+ target = re.sub("#", "", sentence_match.group(1))
+ target = target.strip()
+ sequence = re.sub("(\w+)?#(\w+)?", tokenizer.mask_token, sentence)
+ else:
+ lst = sentence.split()
+ lower = positions[0]
+ upper = positions[1] + 1
+ target = "-".join(lst[lower:upper])
+ if lower == positions[1] or target in self.ft_dict:
+ seqlst = lst[:lower]
+ seqlst.append(tokenizer.mask_token)
+ seqlst.extend(lst[upper:])
+ sequence = " ".join(seqlst)
+ else:
+ rec = list()
+
+ for i in range(lower, upper):
+ seqlst = lst[:lower]
+ seqlst.append(lst[i])
+ seqlst.extend(lst[upper:])
+ rec.append(
+ self.find_top(" ".join(seqlst), [lower, lower], k, top_bert, bert_norm, min_ftext, weights))
+
+ rec = sorted(rec, key=lambda x: x.score.mean(), reverse=True)
+
+ return rec[0]
+
+ self.log.debug("Target word: %s; sequence: %s", target, sequence)
+
+ input = tokenizer.encode(sequence, return_tensors="pt")
+ mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]
+
+ token_logits = model(input)[0]
+ mask_token_logits = token_logits[0, mask_token_index, :]
+
+ # Filter top <top_bert> results of bert output
+ topk = torch.topk(mask_token_logits, top_bert, dim=1)
+ top_tokens = list(zip(topk.indices[0].tolist(), topk.values[0].tolist()))
+
+ unfiltered = list()
+ filtered = list()
+
+ norm_d = top_tokens[bert_norm - 1][1]
+ norm_k = top_tokens[0][1] - norm_d
+
+ self.log.info("Bert finished in %s seconds", '{0:.4f}'.format(time.time() - start_time))
+
+ # Filter bert output by <min_ftext>
+ # TODO: calculate batch similarity
+ for token, value in top_tokens:
+ word = tokenizer.decode([token]).strip()
+ norm_value = (value - norm_d) / norm_k
+
+ sim = cosine_similarity(ft[target].reshape(1, -1), ft[word].reshape(1, -1))[0][0]
+
+ sentence_sim = cosine_similarity(
+ ft.get_sentence_vector(re.sub("#", "", sentence)).reshape(1, -1),
+ ft.get_sentence_vector(re.sub(tokenizer.mask_token, word, sequence)).reshape(1, -1)
+ )[0][0]
+
+ # Continue only for jupyter
+ if self.on_run is None and word == target:
+ continue
+
+ if sim >= min_ftext:
+ filtered.append((word, value, norm_value, sim, sentence_sim, calc_w(norm_value, sim, weights)))
+
+ unfiltered.append((word, value, norm_value, sim, sentence_sim, calc_w(norm_value, sim, weights)))
+
+ done = (time.time() - start_time)
+
+ kfiltered = filtered[:k]
+ kunfiltered = unfiltered[:k]
+
+ kfiltered = sorted(kfiltered, key=lambda x: -x[len(x) - 1])
+ kunfiltered = sorted(kunfiltered, key=lambda x: -x[len(x) - 1])
+
+ filtered_top = pd.DataFrame({
+ 'word': lget(kfiltered, 0),
+ 'bert': self.dget(kfiltered, 1),
+ 'normalized': self.dget(kfiltered, 2),
+ 'ftext': self.dget(kfiltered, 3),
+ 'ftext-sentence': self.dget(kfiltered, 4),
+ 'score': lget(kfiltered, 5),
+ })
+
+ if self.on_run != None:
+ self.on_run(self, kunfiltered, unfiltered, filtered_top, target, tokenizer, top_tokens)
+
+ self.log.info("Processing finished in %s seconds", '{0:.4f}'.format(done))
+
+ return filtered_top
+
+ def do_find(self, s, positions):
+ return self.find_top(s, positions, 10, 200, 200, 0.25, [1, 1])
+
+ def dget(self, lst, pos):
+ return list(map(lambda x: '{0:.2f}'.format(x[pos]), lst)) if self.on_run is not None else lget(lst, pos)
diff --git a/enricher/bertft/utils.py b/enricher/bertft/utils.py
new file mode 100644
index 0000000..af6aa8e
--- /dev/null
+++ b/enricher/bertft/utils.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + "/../"
+
diff --git a/enricher/bin/install_dependencies.sh b/enricher/bin/install_dependencies.sh
new file mode 100755
index 0000000..4f88608
--- /dev/null
+++ b/enricher/bin/install_dependencies.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+abort() {
+ echo "$1"
+ exit 1
+}
+
+[ -x "$(command -v wget)" ] || abort "wget not found"
+[ -x "$(command -v gunzip)" ] || abort "gunzip not found"
+[ -x "$(command -v python3)" ] || abort "python3 not found"
+[ -x "$(command -v pip3)" ] || abort "pip3 not found"
+
+[ ! -f data/cc.en.300.bin.gz ] && \
+ [ ! -f data/cc.en.300.bin ] && \
+ { wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz -P data || \
+ abort "Failed to download fast text data"; }
+[ ! -f data/cc.en.300.bin ] && { gunzip data/cc.en.300.bin.gz || abort "Failed to extract files"; }
+
+[ ! -d /tmp/fastText/ ] && git clone https://github.com/facebookresearch/fastText.git /tmp/fastText
+pip3 install /tmp/fastText || abort "Failed to install fast text python module"
+pip3 install -r bin/py_requirements || abort "Failed to install pip requirements from bin/py_requirements"
+
+rm -rf /tmp/fastText
diff --git a/enricher/bin/predict.sh b/enricher/bin/predict.sh
new file mode 100755
index 0000000..b2b4388
--- /dev/null
+++ b/enricher/bin/predict.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+curl -d "{\"sentence\": \"$1\",\"simple\": true}" -H 'Content-Type: application/json' http://localhost:5000
diff --git a/enricher/bin/py_requirements b/enricher/bin/py_requirements
new file mode 100644
index 0000000..23b3646
--- /dev/null
+++ b/enricher/bin/py_requirements
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+Flask==1.1.2
+transformers==2.7.0
+torch==1.4.0
+pandas==1.0.3
+scikit-learn==0.22.2.post1
+
+# Dependencies required only for jupyter. Uncomment lines below to install them.
+#matplotlib==3.2.1
diff --git a/enricher/bin/start_server.sh b/enricher/bin/start_server.sh
new file mode 100755
index 0000000..8e382e4
--- /dev/null
+++ b/enricher/bin/start_server.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+FLASK_APP=server.py python -m flask run
diff --git a/enricher/jupyter/Trasnsformers-FastText.ipynb b/enricher/jupyter/Trasnsformers-FastText.ipynb
new file mode 100644
index 0000000..8330cb0
--- /dev/null
+++ b/enricher/jupyter/Trasnsformers-FastText.ipynb
@@ -0,0 +1,199 @@
+{
+ "cells": [
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "#\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one or more\n",
+ "# contributor license agreements. See the NOTICE file distributed with\n",
+ "# this work for additional information regarding copyright ownership.\n",
+ "# The ASF licenses this file to You under the Apache License, Version 2.0\n",
+ "# (the \"License\"); you may not use this file except in compliance with\n",
+ "# the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License.\n",
+ "#"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "import os\n",
+ "nb_dir = os.path.split(os.getcwd())[0]\n",
+ "sys.path.append(os.getcwd() + \"/../\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)\n",
+ "logging.getLogger(\"bertft\").setLevel(logging.DEBUG)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import bertft\n",
+ "from bertft import lget\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Important: auto-reload of bertft module\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def mk_graph(x1):\n",
+ " x1 = list(filter(lambda x: -2 < x < 0.99, x1))[:40]\n",
+ " kwargs = dict(alpha=0.3, bins=20)\n",
+ "\n",
+ " plt.hist(x1, **kwargs, color='g', label='FastText score')\n",
+ " plt.gca().set(title='Top 40 masks histogram of embeddings score', ylabel='Count')\n",
+ "\n",
+ " plt.legend()\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "def mk_graph2(x1):\n",
+ " kwargs = dict(alpha=1, bins=50)\n",
+ "\n",
+ " plt.hist(x1, **kwargs, color='r', label='Weighted score')\n",
+ " plt.gca().set(\n",
+ " title='Distribution of weighted score of top 200 unfiltered results (Target excluded)',\n",
+ " ylabel='Count'\n",
+ " )\n",
+ "\n",
+ " plt.legend()\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "def on_run(self, kunfiltered, unfiltered, filtered_top, target, tokenizer, top_tokens):\n",
+ " print(\"Unfiltered top:\")\n",
+ "\n",
+ " print(pd.DataFrame({\n",
+ " 'word': lget(kunfiltered, 0),\n",
+ " 'bert': self.dget(kunfiltered, 1),\n",
+ " 'normalized': self.dget(kunfiltered, 2),\n",
+ " 'ftext': self.dget(kunfiltered, 3),\n",
+ " 'ftext-sentence': self.dget(kunfiltered, 4),\n",
+ " 'score': lget(kunfiltered, 5),\n",
+ " }))\n",
+ "\n",
+ " print(\"Filtered top:\")\n",
+ "\n",
+ " print(filtered_top)\n",
+ "\n",
+ " mk_graph(lget(unfiltered, 2)[:100])\n",
+ " mk_graph2(lget(list(filter(lambda x: x[0] != target, unfiltered)), 4))\n",
+ "\n",
+ " if target is not None:\n",
+ " vec = tokenizer.encode(target, return_tensors=\"pt\")[0]\n",
+ " if len(vec) == 3:\n",
+ " tk = vec[1].item()\n",
+ " pos = None\n",
+ " score = None\n",
+ "\n",
+ " for e, (t, v) in enumerate(top_tokens):\n",
+ " if t == tk:\n",
+ " score = v\n",
+ " break\n",
+ " print(\"Original word position: %s; score: %s \" % (pos, score))\n",
+ " else:\n",
+ " if len(vec) > 3:\n",
+ " print(\"Original word is more then 1 token\")\n",
+ " print(tokenizer.tokenize(target))\n",
+ " else:\n",
+ " print(\"Original word wasn't found\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline = bertft.Pipeline(on_run)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example of usage\n",
+ "x = pipeline.find_top(\n",
+ " \"what is the local weather forecast?\", # mark target word with #\n",
+ " [4, 4], # or pass words position range (inclusive) in the sentece\n",
+ " k = 20, # Filter best k results (by weighted score)\n",
+ " top_bert = 200, # Number of initial filter of bert output \n",
+ " bert_norm = 200, # Use this position for normalization of bert output \n",
+ " min_ftext = 0.3, # Minimal required score of fast text \n",
+ " weights = [ # Weights of models scores to calculate total weighted score\n",
+ " 1, # bert\n",
+ " 1, # fast text\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/enricher/server.py b/enricher/server.py
new file mode 100644
index 0000000..bd21aea
--- /dev/null
+++ b/enricher/server.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from flask import Flask
+from flask import request
+from flask import abort
+from flask import Response
+from bertft import Pipeline
+
+logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.DEBUG)
+
+app = Flask(__name__)
+
+pipeline = Pipeline()
+
+
+@app.route('/', methods=['POST'])
+def main():
+ if not request.is_json:
+ abort(Response("Json expected"))
+
+ json = request.json
+ sentence = json['sentence']
+ upper = None if 'upper' not in json else json['upper']
+ lower = None if 'lower' not in json else json['lower']
+ positions = None if upper is None or lower is None else [lower, upper]
+ data = pipeline.do_find(sentence, positions)
+ if 'simple' not in json or not json['simple']:
+ json_data = data.to_json(orient='table', index=False)
+ else:
+ json_data = data['word'].to_json(orient='values')
+ return app.response_class(response=json_data, status=200, mimetype='application/json')