Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/01/11 07:54:33 UTC
[incubator-mxnet] branch master updated: Add mxnet.text APIs (#8763)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6c1f4f7 Add mxnet.text APIs (#8763)
6c1f4f7 is described below
commit 6c1f4f7023104ba936322fa617d8653e6ed4fbfb
Author: Aston Zhang <as...@amazon.com>
AuthorDate: Wed Jan 10 23:54:25 2018 -0800
Add mxnet.text APIs (#8763)
* Add text utils
* Leftovers
* revise
* before load embeddings
* glossary done
* Add/revise text utils, revise test cases
* Add docstrings
* clean package init
* remove play
* Resolve issues and complete docstrings
* disable pylint
* Remove tqdm dependency
* Add encoding
utf8
utf
utf
utf
* remove non-ascii
* fix textcase
* remove decode in glossary
* py2 unicode
* Fix py2 error
* add tests
* Test all embds
* test some embeds
* Add getter for glossary
* remove util from path, revise interfaces of glossary
* skip some test, before major revise
* Add TextIndexer, only TextEmbed needs revised before major revise
* before major revise
* minor update
* Revise TextIndexer with test
* lint
* lint
* Revise TextEmbed, FastText, Glove, CustmonEmbed with test
* Revision done except for docstr
* Add unit tests for utils
* almost no pylint disable, yeah
* doc minor updates
* re-run
* re-run
* except for register
* except for register
* Revise register/create, add get_registry
* revise
* More readability
* py2 compatibility
* Update doc
* Revise based on feedbacks from NLP team
* add init
* Support indexing for any hashable and comparable token
* Add test cases
* remove type cmp
* Fix doc error and add API descriptions
* Fix api doc error
* add members explicitly
* re-order modules in text.md
* url in one line
* add property desc for all inherited classes for rst parsing
* escape \n
* update glossary example
* escape \n
* add use case
* Make doc more user-friendly
* proper imports, gluon.nn.Embedding use case
* fix links
* re-org link level
* tokens_to_indices
* to_indices, to_tokens
---
docs/api/python/index.md | 9 +
docs/api/python/text/text.md | 443 ++++++++++++++++++++++
python/mxnet/registry.py | 17 +
python/mxnet/text/__init__.py | 25 ++
python/mxnet/text/constants.py | 24 ++
python/mxnet/text/embedding.py | 681 +++++++++++++++++++++++++++++++++
python/mxnet/text/glossary.py | 142 +++++++
python/mxnet/text/indexer.py | 231 ++++++++++++
python/mxnet/text/utils.py | 79 ++++
tests/python/unittest/test_text.py | 743 +++++++++++++++++++++++++++++++++++++
10 files changed, 2394 insertions(+)
diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index 75ff186..7a3ad7c 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -98,6 +98,15 @@ imported by running:
io/io.md
```
+## Text API
+
+```eval_rst
+.. toctree::
+ :maxdepth: 1
+
+ text/text.md
+```
+
## Image API
```eval_rst
diff --git a/docs/api/python/text/text.md b/docs/api/python/text/text.md
new file mode 100644
index 0000000..a448ae4
--- /dev/null
+++ b/docs/api/python/text/text.md
@@ -0,0 +1,443 @@
+# Text API
+
+## Overview
+
+The mxnet.text APIs refer to classes and functions related to text data
+processing, such as building indices and loading pre-trained embedding vectors
+for text tokens and storing them in the `mxnet.ndarray.NDArray` format.
+
+This document lists the text APIs in mxnet:
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ mxnet.text.glossary
+ mxnet.text.embedding
+ mxnet.text.indexer
+ mxnet.text.utils
+```
+
+All the code demonstrated in this document assumes that the following modules
+or packages are imported.
+
+```python
+>>> from mxnet import gluon
+>>> from mxnet import nd
+>>> from mxnet import text
+>>> import collections
+
+```
+
+### Look up pre-trained word embeddings for indexed words
+
+As a common use case, let us look up pre-trained word embedding vectors for
+indexed words in just a few lines of code. To begin with, we can create a
+fastText word embedding object by specifying the embedding name `fasttext` and
+the pre-trained file `wiki.simple.vec`.
+
+```python
+>>> fasttext_simple = text.embedding.TokenEmbedding.create('fasttext',
+... pretrained_file_name='wiki.simple.vec')
+
+```
+
+Suppose that we have a simple text data set in the string format. We can count
+word frequency in the data set.
+
+```python
+>>> text_data = " hello world \n hello nice world \n hi world \n"
+>>> counter = text.utils.count_tokens_from_str(text_data)
+
+```
+
+The obtained `counter` has key-value pairs whose keys are words and values are
+word frequencies. Suppose that we want to build indices for all the keys in
+`counter` and load the defined fastText word embedding for all such indexed
+words. We can create a glossary object by specifying `counter` and
+`fasttext_simple` as its arguments.
+
+```python
+>>> glossary = text.glossary.Glossary(counter, fasttext_simple)
+
+```
+
+Now we are ready to look up the fastText word embedding vectors for indexed
+words.
+
+```python
+>>> glossary.get_vecs_by_tokens(['hello', 'world'])
+
+[[ 3.95669997e-01 2.14540005e-01 -3.53889987e-02 -2.42990002e-01
+ ...
+ -7.54180014e-01 -3.14429998e-01 2.40180008e-02 -7.61009976e-02]
+ [ 1.04440004e-01 -1.08580001e-01 2.72119999e-01 1.32990003e-01
+ ...
+ -3.73499990e-01 5.67310005e-02 5.60180008e-01 2.90190000e-02]]
+<NDArray 2x300 @cpu(0)>
+
+```
+
+### Use `glossary` in `gluon`
+
+To demonstrate how to use a glossary with the loaded word embedding in the
+`gluon` package, let us first obtain indices of the words 'hello' and 'world'.
+
+```python
+>>> glossary.to_indices(['hello', 'world'])
+[2, 1]
+
+```
+
+We can obtain the vector representations for the words 'hello' and 'world'
+by passing their indices (2 and 1) to an `mxnet.gluon.nn.Embedding` layer whose
+weight is set to `glossary.idx_to_vec`.
+
+```python
+>>> layer = gluon.nn.Embedding(len(glossary), glossary.vec_len)
+>>> layer.initialize()
+>>> layer.weight.set_data(glossary.idx_to_vec)
+>>> layer(nd.array([2, 1]))
+
+[[ 3.95669997e-01 2.14540005e-01 -3.53889987e-02 -2.42990002e-01
+ ...
+ -7.54180014e-01 -3.14429998e-01 2.40180008e-02 -7.61009976e-02]
+ [ 1.04440004e-01 -1.08580001e-01 2.72119999e-01 1.32990003e-01
+ ...
+ -3.73499990e-01 5.67310005e-02 5.60180008e-01 2.90190000e-02]]
+<NDArray 2x300 @cpu(0)>
+
+```
+
+
+## Glossary
+
+The glossary provides indexing and embedding for text tokens. For each
+indexed token in a glossary, an embedding vector will be associated with
+it. Such embedding vectors can be loaded from externally hosted or custom
+pre-trained token embedding files, such as via instances of
+[`TokenEmbedding`](#mxnet.text.embedding.TokenEmbedding).
+The input counter, whose keys are the candidate tokens to be indexed,
+may be obtained via
+[`count_tokens_from_str`](#mxnet.text.utils.count_tokens_from_str).
+
+```eval_rst
+.. currentmodule:: mxnet.text.glossary
+.. autosummary::
+ :nosignatures:
+
+ Glossary
+```
+
+To get all the valid names for pre-trained embeddings and files, we can use
+[`TokenEmbedding.get_embedding_and_pretrained_file_names`](#mxnet.text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names).
+
+```python
+>>> text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names()
+{'glove': ['glove.42B.300d.txt', 'glove.6B.50d.txt', 'glove.6B.100d.txt',
+'glove.6B.200d.txt', 'glove.6B.300d.txt', 'glove.840B.300d.txt',
+'glove.twitter.27B.25d.txt', 'glove.twitter.27B.50d.txt',
+'glove.twitter.27B.100d.txt', 'glove.twitter.27B.200d.txt'],
+'fasttext': ['wiki.en.vec', 'wiki.simple.vec', 'wiki.zh.vec']}
+
+```
+
+To begin with, we can create a fastText word embedding object by specifying the
+embedding name `fasttext` and the pre-trained file `wiki.simple.vec`.
+
+```python
+>>> fasttext_simple = text.embedding.TokenEmbedding.create('fasttext',
+... pretrained_file_name='wiki.simple.vec')
+
+```
+
+Suppose that we have a simple text data set in the string format. We can count
+word frequency in the data set.
+
+```python
+>>> text_data = " hello world \n hello nice world \n hi world \n"
+>>> counter = text.utils.count_tokens_from_str(text_data)
+
+```
+
+The obtained `counter` has key-value pairs whose keys are words and values are
+word frequencies. Suppose that we want to build indices for the 2 most frequent
+keys in `counter` and load the defined fastText word embedding for these
+2 words.
+
+```python
+>>> glossary = text.glossary.Glossary(counter, fasttext_simple, most_freq_count=2)
+
+```
+
+Now we are ready to look up the fastText word embedding vectors for indexed
+words.
+
+```python
+>>> glossary.get_vecs_by_tokens(['hello', 'world'])
+
+[[ 3.95669997e-01 2.14540005e-01 -3.53889987e-02 -2.42990002e-01
+ ...
+ -7.54180014e-01 -3.14429998e-01 2.40180008e-02 -7.61009976e-02]
+ [ 1.04440004e-01 -1.08580001e-01 2.72119999e-01 1.32990003e-01
+ ...
+ -3.73499990e-01 5.67310005e-02 5.60180008e-01 2.90190000e-02]]
+<NDArray 2x300 @cpu(0)>
+
+```
+
+We can also access properties such as `token_to_idx` (mapping tokens to
+indices), `idx_to_token` (mapping indices to tokens), and `vec_len`
+(length of each embedding vector).
+
+```python
+>>> glossary.token_to_idx
+{'<unk>': 0, 'world': 1, 'hello': 2}
+>>> glossary.idx_to_token
+['<unk>', 'world', 'hello']
+>>> len(glossary)
+3
+>>> glossary.vec_len
+300
+
+```
+
+If a token is unknown to `glossary`, its embedding vector is initialized
+according to the default specification in `fasttext_simple` (all elements are
+0).
+
+```python
+
+>>> glossary.get_vecs_by_tokens('unknownT0kEN')
+
+[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
+ ...
+ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
+<NDArray 300 @cpu(0)>
+
+```
+
+## Text token embedding
+
+The text token embedding indexes text tokens and associates each indexed token
+with an embedding vector. Instances of
+[`TokenEmbedding`](#mxnet.text.embedding.TokenEmbedding) can be passed to
+[`Glossary`](#mxnet.text.glossary.Glossary).
+
+To load token embeddings from an externally hosted pre-trained token embedding
+file, such as those of GloVe and FastText, use
+[`TokenEmbedding.create(embedding_name, pretrained_file_name)`](#mxnet.text.embedding.TokenEmbedding.create).
+To get all the available `embedding_name` and `pretrained_file_name`, use
+[`TokenEmbedding.get_embedding_and_pretrained_file_names()`](#mxnet.text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names).
+
+Alternatively, to load embedding vectors from a custom pre-trained text token
+embedding file, use [`CustomEmbedding`](#mxnet.text.embedding.CustomEmbedding).
+
+
+```eval_rst
+.. currentmodule:: mxnet.text.embedding
+.. autosummary::
+ :nosignatures:
+
+ TokenEmbedding
+ GloVe
+ FastText
+ CustomEmbedding
+```
+
+To get all the valid names for pre-trained embeddings and files, we can use
+[`TokenEmbedding.get_embedding_and_pretrained_file_names`](#mxnet.text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names).
+
+```python
+>>> text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names()
+{'glove': ['glove.42B.300d.txt', 'glove.6B.50d.txt', 'glove.6B.100d.txt',
+'glove.6B.200d.txt', 'glove.6B.300d.txt', 'glove.840B.300d.txt',
+'glove.twitter.27B.25d.txt', 'glove.twitter.27B.50d.txt',
+'glove.twitter.27B.100d.txt', 'glove.twitter.27B.200d.txt'],
+'fasttext': ['wiki.en.vec', 'wiki.simple.vec', 'wiki.zh.vec']}
+
+```
+
+To begin with, we can create a GloVe word embedding object by specifying the
+embedding name `glove` and the pre-trained file `glove.6B.50d.txt`. The
+argument `init_unknown_vec` specifies the default vector representation for any
+unknown token.
+
+```python
+>>> glove_6b_50d = text.embedding.TokenEmbedding.create('glove',
+... pretrained_file_name='glove.6B.50d.txt', init_unknown_vec=nd.zeros)
+
+```
+
+We can access properties such as `token_to_idx` (mapping tokens to indices),
+`idx_to_token` (mapping indices to tokens), `vec_len` (length of each embedding
+vector), and `unknown_token` (representation of any unknown token, default
+value is '<unk>').
+
+```python
+>>> glove_6b_50d.token_to_idx['hi']
+11084
+>>> glove_6b_50d.idx_to_token[11084]
+'hi'
+>>> glove_6b_50d.vec_len
+50
+>>> glove_6b_50d.unknown_token
+'<unk>'
+
+```
+
+If the unknown token representation '<unk>' is encountered in the pre-trained
+token embedding file, index 0 of property `idx_to_vec` maps to the pre-trained
+token embedding vector loaded from the file; otherwise, index 0 of property
+`idx_to_vec` maps to the default token embedding vector specified via
+`init_unknown_vec` (set to `nd.zeros` here). Since the pre-trained file
+does not have a vector for the token '<unk>', index 0 maps to an additional
+token '<unk>', so the number of tokens in the embedding is 400,001.
+
+
+```python
+>>> len(glove_6b_50d)
+400001
+>>> glove_6b_50d.idx_to_vec[0]
+
+[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
+ ...
+ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
+<NDArray 50 @cpu(0)>
+>>> glove_6b_50d.get_vecs_by_tokens('unknownT0kEN')
+
+[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
+ ...
+ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
+<NDArray 50 @cpu(0)>
+>>> glove_6b_50d.get_vecs_by_tokens(['unknownT0kEN', 'unknownT0kEN'])
+
+[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
+ ...
+ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
+ [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
+ ...
+ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
+<NDArray 2x50 @cpu(0)>
+
+```
+
+
+### Implement a new text token embedding
+
+To implement a new text token embedding, create a subclass of
+[`TokenEmbedding`](#mxnet.text.embedding.TokenEmbedding).
+Also add ``@TokenEmbedding.register`` before this class. See
+[`embedding.py`](https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/text/embedding.py)
+for examples.
+
+
+## Text token indexer
+
+The text token indexer builds indices for text tokens. Such indexed tokens can
+be used by instances of [`TokenEmbedding`](#mxnet.text.embedding.TokenEmbedding)
+and [`Glossary`](#mxnet.text.glossary.Glossary). The input
+counter, whose keys are the candidate tokens to be indexed, may be obtained via
+[`count_tokens_from_str`](#mxnet.text.utils.count_tokens_from_str).
+
+
+```eval_rst
+.. currentmodule:: mxnet.text.indexer
+.. autosummary::
+ :nosignatures:
+
+ TokenIndexer
+```
+
+Suppose that we have a simple text data set in the string format. We can count
+word frequency in the data set.
+
+```python
+>>> text_data = " hello world \n hello nice world \n hi world \n"
+>>> counter = text.utils.count_tokens_from_str(text_data)
+
+```
+
+The obtained `counter` has key-value pairs whose keys are words and values are
+word frequencies. Suppose that we want to build indices for the 2 most frequent
+keys in `counter` with the unknown token representation '<UnK>' and a reserved
+token '<pad>'.
+
+```python
+>>> token_indexer = text.indexer.TokenIndexer(counter, most_freq_count=2,
+... unknown_token='<UnK>', reserved_tokens=['<pad>'])
+
+```
+
+We can access properties such as `token_to_idx` (mapping tokens to indices),
+`idx_to_token` (mapping indices to tokens), `unknown_token` (representation of
+any unknown token), and `reserved_tokens`.
+
+```python
+>>> token_indexer.token_to_idx
+{'<UnK>': 0, '<pad>': 1, 'world': 2, 'hello': 3}
+>>> token_indexer.idx_to_token
+['<UnK>', '<pad>', 'world', 'hello']
+>>> token_indexer.unknown_token
+'<UnK>'
+>>> token_indexer.reserved_tokens
+['<pad>']
+>>> len(token_indexer)
+4
+```
+
+Besides the specified unknown token '<UnK>' and reserved token '<pad>', the 2
+most frequent words 'world' and 'hello' are also indexed.
+
+
+
+## Text utilities
+
+The following functions provide utilities for text data processing.
+
+```eval_rst
+.. currentmodule:: mxnet.text.utils
+.. autosummary::
+ :nosignatures:
+
+ count_tokens_from_str
+```
+
+
+
+
+## API Reference
+
+<script type="text/javascript" src='../../_static/js/auto_module_index.js'></script>
+
+```eval_rst
+
+.. automodule:: mxnet.text.glossary
+.. autoclass:: mxnet.text.glossary.Glossary
+ :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
+
+.. automodule:: mxnet.text.embedding
+.. autoclass:: mxnet.text.embedding.TokenEmbedding
+ :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens, register, create, get_embedding_and_pretrained_file_names
+.. autoclass:: mxnet.text.embedding.GloVe
+ :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
+.. autoclass:: mxnet.text.embedding.FastText
+ :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
+.. autoclass:: mxnet.text.embedding.CustomEmbedding
+ :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
+
+.. automodule:: mxnet.text.indexer
+.. autoclass:: mxnet.text.indexer.TokenIndexer
+ :members: to_indices, to_tokens
+
+.. automodule:: mxnet.text.utils
+ :members: count_tokens_from_str
+
+```
+<script>auto_index("api-reference");</script>
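The counting and indexing behavior documented in `text.md` above can be sketched in plain Python without MXNet. This is a minimal sketch under stated assumptions, not the `mxnet.text` implementation: `count_tokens` and `build_index` are hypothetical helpers, tokens are assumed to be delimited by spaces and newlines, and frequency ties are assumed to be broken alphabetically. It follows the layout shown in the `TokenIndexer` example: the unknown token at index 0, then the reserved tokens, then the most frequent tokens.

```python
from collections import Counter


def count_tokens(source_str, token_delim=' ', seq_delim='\n'):
    """Hypothetical stand-in for text.utils.count_tokens_from_str.

    Treats both delimiters as token separators and counts non-empty tokens.
    """
    tokens = source_str.replace(seq_delim, token_delim).split(token_delim)
    return Counter(token for token in tokens if token)


def build_index(counter, most_freq_count=None, unknown_token='<unk>',
                reserved_tokens=None):
    """Hypothetical sketch of the index layout described for TokenIndexer."""
    # Index 0 holds the unknown token, then the reserved tokens follow.
    idx_to_token = [unknown_token] + list(reserved_tokens or [])
    specials = set(idx_to_token)
    # Assumption: higher frequency first, ties broken alphabetically.
    candidates = sorted(
        (item for item in counter.items() if item[0] not in specials),
        key=lambda item: (-item[1], item[0]))
    if most_freq_count is not None:
        candidates = candidates[:most_freq_count]
    idx_to_token.extend(token for token, _ in candidates)
    token_to_idx = {token: idx for idx, token in enumerate(idx_to_token)}
    return token_to_idx, idx_to_token


text_data = " hello world \n hello nice world \n hi world \n"
counter = count_tokens(text_data)
token_to_idx, idx_to_token = build_index(
    counter, most_freq_count=2, unknown_token='<UnK>',
    reserved_tokens=['<pad>'])
print(idx_to_token)  # ['<UnK>', '<pad>', 'world', 'hello']
```

Called with its default arguments on the same `counter`, the sketch yields `['<unk>', 'world', 'hello', 'hi', 'nice']`.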
diff --git a/python/mxnet/registry.py b/python/mxnet/registry.py
index 4a4f22f..eaae920 100644
--- a/python/mxnet/registry.py
+++ b/python/mxnet/registry.py
@@ -29,6 +29,23 @@ from .base import string_types
_REGISTRY = {}
+def get_registry(base_class):
+ """Get a copy of the registry for a base class.
+
+ Parameters
+ ----------
+ base_class : type
+ base class for classes that will be registered
+
+ Returns
+ -------
+ dict
+ a copy of the registry, mapping registered names to registered classes
+ """
+ if base_class not in _REGISTRY:
+ _REGISTRY[base_class] = {}
+ return _REGISTRY[base_class].copy()
+
+
def get_register_func(base_class, nickname):
"""Get registrator function.
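`get_registry` returns a copy of the per-base-class dictionary held in `_REGISTRY`, so callers can inspect what has been registered without mutating the registry itself. The self-contained sketch below mimics this pattern. Its `register` and `create` functions are hypothetical, simplified stand-ins for the functions returned by `get_register_func` and `get_create_func`. Keying classes by their lower-case class name is an assumption based on `create` being documented as case-insensitive.

```python
_REGISTRY = {}


def get_registry(base_class):
    """Return a copy of the name -> class mapping registered under base_class."""
    if base_class not in _REGISTRY:
        _REGISTRY[base_class] = {}
    return _REGISTRY[base_class].copy()


def register(base_class):
    """Hypothetical simplified registrar: store a subclass by lower-case name."""
    def _register(klass):
        assert issubclass(klass, base_class)
        _REGISTRY.setdefault(base_class, {})[klass.__name__.lower()] = klass
        return klass
    return _register


def create(base_class, name, **kwargs):
    """Hypothetical simplified factory: case-insensitive lookup, then build."""
    return _REGISTRY[base_class][name.lower()](**kwargs)


class TokenEmbedding(object):
    """Placeholder base class for this sketch only."""


@register(TokenEmbedding)
class MyTokenEmbed(TokenEmbedding):
    def __init__(self, pretrained_file_name='my_pretrain_file'):
        self.pretrained_file_name = pretrained_file_name


snapshot = get_registry(TokenEmbedding)
snapshot.clear()  # mutating the copy leaves the real registry untouched
embed = create(TokenEmbedding, 'MyTokenEmbed')
print(type(embed).__name__)  # MyTokenEmbed
```

Returning a copy rather than the live dictionary means that code listing the available embeddings cannot accidentally register or drop entries.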
diff --git a/python/mxnet/text/__init__.py b/python/mxnet/text/__init__.py
new file mode 100644
index 0000000..16035b7
--- /dev/null
+++ b/python/mxnet/text/__init__.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Text utilities."""
+
+from . import utils
+from . import constants
+from . import indexer
+from . import embedding
+from . import glossary
diff --git a/python/mxnet/text/constants.py b/python/mxnet/text/constants.py
new file mode 100644
index 0000000..a36d5af
--- /dev/null
+++ b/python/mxnet/text/constants.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Read text files and load embeddings."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+UNKNOWN_IDX = 0
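`UNKNOWN_IDX = 0` fixes the index that every unknown token maps to. The lookup rule applied by `TokenEmbedding.get_vecs_by_tokens` in `embedding.py` (exact match first, then the lower-case form when `lower_case_backup=True`, then `UNKNOWN_IDX`) can be sketched with plain lists standing in for NDArrays. `lookup_indices` and `lookup_vecs` are hypothetical helpers, and the 2-dimensional vectors are made-up illustration data.

```python
UNKNOWN_IDX = 0


def lookup_indices(tokens, token_to_idx, lower_case_backup=False):
    """Map tokens to indices, falling back to UNKNOWN_IDX for unknown tokens."""
    indices = []
    for token in tokens:
        if token in token_to_idx:
            indices.append(token_to_idx[token])
        elif lower_case_backup:
            indices.append(token_to_idx.get(token.lower(), UNKNOWN_IDX))
        else:
            indices.append(UNKNOWN_IDX)
    return indices


def lookup_vecs(tokens, token_to_idx, idx_to_vec, lower_case_backup=False):
    """Return one vector for a single token, or a list of vectors for a list."""
    single = not isinstance(tokens, list)
    if single:
        tokens = [tokens]
    indices = lookup_indices(tokens, token_to_idx, lower_case_backup)
    vecs = [idx_to_vec[idx] for idx in indices]
    return vecs[0] if single else vecs


# Illustrative data: row 0 holds the unknown-token vector (all zeros here).
token_to_idx = {'<unk>': 0, 'world': 1, 'hello': 2}
idx_to_vec = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]]

print(lookup_vecs('Hello', token_to_idx, idx_to_vec))  # [0.0, 0.0]
print(lookup_vecs('Hello', token_to_idx, idx_to_vec,
                  lower_case_backup=True))  # [0.3, 0.4]
print(lookup_indices(['world', 'unknownT0kEN'], token_to_idx))  # [1, 0]
```

Returning a single vector for a string and a list of vectors for a list mirrors the 1-D versus 2-D return shapes described in the `get_vecs_by_tokens` docstring.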
diff --git a/python/mxnet/text/embedding.py b/python/mxnet/text/embedding.py
new file mode 100644
index 0000000..5b45e58
--- /dev/null
+++ b/python/mxnet/text/embedding.py
@@ -0,0 +1,681 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=consider-iterating-dictionary
+
+"""Text token embeddings."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+import io
+import logging
+import os
+import tarfile
+import warnings
+import zipfile
+
+from . import constants as C
+from ..gluon.utils import download
+from .indexer import TokenIndexer
+from .. import ndarray as nd
+from .. import registry
+
+
+class TokenEmbedding(TokenIndexer):
+ """Token embedding base class.
+
+
+ To load token embeddings from an externally hosted pre-trained
+ token embedding file, such as those of GloVe and FastText, use
+ `TokenEmbedding.create(embedding_name, pretrained_file_name)`. To get all
+ the available `embedding_name` and `pretrained_file_name`, use
+ `TokenEmbedding.get_embedding_and_pretrained_file_names()`.
+
+ Alternatively, to load embedding vectors from a custom pre-trained token
+ embedding file, use :class:`~mxnet.text.embedding.CustomEmbedding`.
+
+ For every unknown token, if its representation `self.unknown_token` is
+ encountered in the pre-trained token embedding file, index 0 of
+ `self.idx_to_vec` maps to the pre-trained token embedding vector loaded from
+ the file; otherwise, index 0 of `self.idx_to_vec` maps to the token
+ embedding vector initialized by `init_unknown_vec`.
+
+ If a token is encountered multiple times in the pre-trained token embedding
+ file, only the first-encountered token embedding vector will be loaded and
+ the rest will be skipped.
+
+ For the same token, its index and embedding vector may vary across different
+ instances of :class:`~mxnet.text.embedding.TokenEmbedding`.
+
+
+ Properties
+ ----------
+ token_to_idx : dict mapping str to int
+ A dict mapping each token to its index integer.
+ idx_to_token : list of strs
+ A list of indexed tokens where the list indices and the token indices
+ are aligned.
+ unknown_token : hashable object
+ The representation for any unknown token. In other words, any
+ unknown token will be indexed as the same representation.
+ reserved_tokens : list of strs or None
+ A list of reserved tokens that will always be indexed.
+ vec_len : int
+ The length of the embedding vector for each token.
+ idx_to_vec : mxnet.ndarray.NDArray
+ For all the indexed tokens in this embedding, this NDArray maps each
+ token's index to an embedding vector. Index 0 maps to the embedding
+ vector of the unknown token, which is either loaded from the pre-trained
+ file or initialized by `init_unknown_vec`.
+ """
+
+ def __init__(self, **kwargs):
+ super(TokenEmbedding, self).__init__(**kwargs)
+
+ @classmethod
+ def _get_pretrained_file_path_from_url(cls, url, embedding_root,
+ pretrained_file_name):
+ """Get the local path to the pre-trained token embedding file from url.
+
+
+ The pre-trained embedding file will be downloaded from url if it has not
+ been downloaded yet or the existing file fails to match its expected
+ SHA-1 hash.
+ """
+
+ embedding_cls = cls.__name__.lower()
+ embedding_root = os.path.expanduser(embedding_root)
+
+ embedding_dir = os.path.join(embedding_root, embedding_cls)
+ pretrained_file_path = os.path.join(embedding_dir, pretrained_file_name)
+ downloaded_file = os.path.basename(url)
+ downloaded_file_path = os.path.join(embedding_dir, downloaded_file)
+
+ expected_file_hash = cls.pretrained_file_name_sha1[pretrained_file_name]
+
+ if hasattr(cls, 'pretrained_archive_name_sha1'):
+ expected_downloaded_hash = \
+ cls.pretrained_archive_name_sha1[downloaded_file]
+ else:
+ expected_downloaded_hash = expected_file_hash
+
+ # If downloaded_file_path exists and matches expected_downloaded_hash,
+ # there is no need to download.
+ download(url, downloaded_file_path, sha1_hash=expected_downloaded_hash)
+
+ ext = os.path.splitext(downloaded_file)[1]
+ if ext == '.zip':
+ with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
+ zf.extractall(embedding_dir)
+ elif ext == '.gz':
+ with tarfile.open(downloaded_file_path, 'r:gz') as tar:
+ tar.extractall(path=embedding_dir)
+ return pretrained_file_path
+
+ def _load_embedding(self, pretrained_file_path, elem_delim,
+ init_unknown_vec, encoding='utf8'):
+ """Load embedding vectors from the pre-trained token embedding file.
+
+
+ For every unknown token, if its representation `self.unknown_token` is
+ encountered in the pre-trained token embedding file, index 0 of
+ `self.idx_to_vec` maps to the pre-trained token embedding vector loaded
+ from the file; otherwise, index 0 of `self.idx_to_vec` maps to the text
+ embedding vector initialized by `init_unknown_vec`.
+
+ If a token is encountered multiple times in the pre-trained text
+ embedding file, only the first-encountered token embedding vector will
+ be loaded and the rest will be skipped.
+ """
+
+ pretrained_file_path = os.path.expanduser(pretrained_file_path)
+
+ if not os.path.isfile(pretrained_file_path):
+ raise ValueError('`pretrained_file_path` must be a valid path to '
+ 'the pre-trained token embedding file.')
+
+ with io.open(pretrained_file_path, 'r', encoding=encoding) as f:
+ lines = f.readlines()
+
+ logging.info('Loading pre-trained token embedding vectors from %s',
+ pretrained_file_path)
+
+ vec_len = None
+ all_elems = []
+ tokens = set()
+ loaded_unknown_vec = None
+ line_num = 0
+ for line in lines:
+ line_num += 1
+ elems = line.rstrip().split(elem_delim)
+
+ assert len(elems) > 1, 'At line %d of the pre-trained text ' \
+ 'embedding file: the data format of the ' \
+ 'pre-trained token embedding file %s is ' \
+ 'unexpected.' \
+ % (line_num, pretrained_file_path)
+
+ token, elems = elems[0], [float(i) for i in elems[1:]]
+
+ if token == self.unknown_token and loaded_unknown_vec is None:
+ loaded_unknown_vec = elems
+ tokens.add(self.unknown_token)
+ elif token in tokens:
+ warnings.warn('At line %d of the pre-trained token embedding '
+ 'file: the embedding vector for token %s has '
+ 'been loaded and a duplicate embedding for the '
+ 'same token is seen and skipped.'
+ % (line_num, token))
+ elif len(elems) == 1:
+ warnings.warn('At line %d of the pre-trained text '
+ 'embedding file: token %s with 1-dimensional '
+ 'vector %s is likely a header and is '
+ 'skipped.' % (line_num, token, elems))
+ else:
+ if vec_len is None:
+ vec_len = len(elems)
+ # Reserve a vector slot for the unknown token at the
+ # very beginning because the unknown index is 0.
+ all_elems.extend([0] * vec_len)
+ else:
+ assert len(elems) == vec_len, \
+ 'At line %d of the pre-trained token embedding ' \
+ 'file: the dimension of token %s is %d but the ' \
+ 'dimension of previous tokens is %d. Dimensions ' \
+ 'of all the tokens must be the same.' \
+ % (line_num, token, len(elems), vec_len)
+ all_elems.extend(elems)
+ self._idx_to_token.append(token)
+ self._token_to_idx[token] = len(self._idx_to_token) - 1
+ tokens.add(token)
+
+ self._vec_len = vec_len
+ self._idx_to_vec = nd.array(all_elems).reshape((-1, self.vec_len))
+
+ if loaded_unknown_vec is None:
+ self._idx_to_vec[C.UNKNOWN_IDX] = init_unknown_vec(
+ shape=self.vec_len)
+ else:
+ self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
+
+ @property
+ def vec_len(self):
+ return self._vec_len
+
+ @property
+ def idx_to_vec(self):
+ return self._idx_to_vec
+
+ def get_vecs_by_tokens(self, tokens, lower_case_backup=False):
+ """Look up embedding vectors of tokens.
+
+
+ Parameters
+ ----------
+ tokens : str or list of strs
+ A token or a list of tokens.
+ lower_case_backup : bool, default False
+ If False, each token in the original case will be looked up; if
+ True, each token in the original case will be looked up first, if
+ not found in the keys of the property `token_to_idx`, the token
+ in the lower case will be looked up.
+
+
+ Returns
+ -------
+ mxnet.ndarray.NDArray:
+ The embedding vector(s) of the token(s). According to numpy
+ conventions, if `tokens` is a string, returns a 1-D NDArray of shape
+ `self.vec_len`; if `tokens` is a list of strings, returns a 2-D
+ NDArray of shape=(len(tokens), self.vec_len).
+ """
+
+ to_reduce = False
+ if not isinstance(tokens, list):
+ tokens = [tokens]
+ to_reduce = True
+
+ if not lower_case_backup:
+ indices = [self.token_to_idx.get(token, C.UNKNOWN_IDX)
+ for token in tokens]
+ else:
+ indices = [self.token_to_idx[token] if token in self.token_to_idx
+ else self.token_to_idx.get(token.lower(), C.UNKNOWN_IDX)
+ for token in tokens]
+
+ vecs = nd.Embedding(nd.array(indices), self.idx_to_vec,
+ self.idx_to_vec.shape[0], self.idx_to_vec.shape[1])
+
+ return vecs[0] if to_reduce else vecs
+
+ def update_token_vectors(self, tokens, new_vectors):
+ """Updates embedding vectors for tokens.
+
+
+ Parameters
+ ----------
+ tokens : str or a list of strs
+ A token or a list of tokens whose embedding vectors are to be
+ updated.
+ new_vectors : mxnet.ndarray.NDArray
+ An NDArray to be assigned to the embedding vectors of `tokens`.
+ Its length must be equal to the number of `tokens` and its width
+ must be equal to the dimension of embeddings of the glossary. If
+ `tokens` is a singleton, it must be 1-D or 2-D. If `tokens` is a
+ list of multiple strings, it must be 2-D.
+ """
+
+ assert self.idx_to_vec is not None, \
+ 'The property `idx_to_vec` has not been properly set.'
+
+ if not isinstance(tokens, list) or len(tokens) == 1:
+ assert isinstance(new_vectors, nd.NDArray) and \
+ len(new_vectors.shape) in [1, 2], \
+ '`new_vectors` must be a 1-D or 2-D NDArray if `tokens` is a ' \
+ 'singleton.'
+ if not isinstance(tokens, list):
+ tokens = [tokens]
+ if len(new_vectors.shape) == 1:
+ new_vectors = new_vectors.expand_dims(0)
+
+ else:
+ assert isinstance(new_vectors, nd.NDArray) and \
+ len(new_vectors.shape) == 2, \
+ '`new_vectors` must be a 2-D NDArray if `tokens` is a list ' \
+ 'of multiple strings.'
+ assert new_vectors.shape == (len(tokens), self.vec_len), \
+ 'The length of new_vectors must be equal to the number of tokens ' \
+ 'and the width of new_vectors must be equal to the dimension of ' \
+ 'embeddings of the glossary.'
+
+ indices = []
+ for token in tokens:
+ if token in self.token_to_idx:
+ indices.append(self.token_to_idx[token])
+ else:
+ raise ValueError('Token %s is unknown. To update the embedding '
+ 'vector for an unknown token, please specify '
+ 'it explicitly as the `unknown_token` %s in '
+ '`tokens`. This is to avoid unintended '
+ 'updates.' %
+ (token, self.idx_to_token[C.UNKNOWN_IDX]))
+
+ self._idx_to_vec[nd.array(indices)] = new_vectors
+
+ @staticmethod
+ def register(embedding_cls):
+ """Registers a new token embedding.
+
+
+ Once an embedding is registered, we can create an instance of this
+ embedding with :func:`~mxnet.text.embedding.TokenEmbedding.create`.
+
+
+ Examples
+ --------
+ >>> @mxnet.text.embedding.TokenEmbedding.register
+ ... class MyTokenEmbed(mxnet.text.embedding.TokenEmbedding):
+ ... def __init__(self, pretrained_file_name='my_pretrain_file'):
+ ... pass
+ >>> embed = mxnet.text.embedding.TokenEmbedding.create('MyTokenEmbed')
+ >>> print(type(embed))
+ <class '__main__.MyTokenEmbed'>
+ """
+
+ register_text_embedding = registry.get_register_func(
+ TokenEmbedding, 'token embedding')
+ return register_text_embedding(embedding_cls)
+
+ @staticmethod
+ def create(embedding_name, **kwargs):
+ """Creates an instance of :class:`~mxnet.text.embedding.TokenEmbedding`.
+
+
+ Creates a token embedding instance by loading embedding vectors from an
+ externally hosted pre-trained token embedding file, such as those
+ of GloVe and FastText. To get all the valid values of `embedding_name`
+ and `pretrained_file_name`, use
+ `mxnet.text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names()`.
+
+
+ Parameters
+ ----------
+ embedding_name : str
+ The token embedding name (case-insensitive).
+
+
+ Returns
+ -------
+ :class:`~mxnet.text.embedding.TokenEmbedding`:
+ A token embedding instance that loads embedding vectors from an
+ externally hosted pre-trained token embedding file.
+ """
+
+ create_text_embedding = registry.get_create_func(
+ TokenEmbedding, 'token embedding')
+ return create_text_embedding(embedding_name, **kwargs)
+
+ @classmethod
+ def _check_pretrained_file_names(cls, pretrained_file_name):
+ """Checks if a pre-trained token embedding file name is valid.
+
+
+ Parameters
+ ----------
+ pretrained_file_name : str
+ The pre-trained token embedding file.
+ """
+
+ embedding_name = cls.__name__.lower()
+ if pretrained_file_name not in cls.pretrained_file_name_sha1:
+ raise KeyError('Cannot find pretrained file %s for token embedding '
+ '%s. Valid pretrained files for embedding %s: %s' %
+ (pretrained_file_name, embedding_name,
+ embedding_name,
+ ', '.join(cls.pretrained_file_name_sha1.keys())))
+
+ @staticmethod
+ def get_embedding_and_pretrained_file_names(embedding_name=None):
+ """Get valid token embedding names and their pre-trained file names.
+
+
+ To load token embedding vectors from an externally hosted pre-trained
+ token embedding file, such as those of GloVe and FastText, one should
+ use `mxnet.text.embedding.TokenEmbedding.create(embedding_name,
+ pretrained_file_name)`. This method returns all the valid names of
+ `pretrained_file_name` for the specified `embedding_name`. If
+ `embedding_name` is set to None, this method returns all the valid names
+ of `embedding_name` with associated `pretrained_file_name`.
+
+
+ Parameters
+ ----------
+ embedding_name : str or None, default None
+ The pre-trained token embedding name.
+
+
+ Returns
+ -------
+ dict or list:
+ A list of all the valid pre-trained token embedding file names
+ (`pretrained_file_name`) for the specified token embedding name
+ (`embedding_name`). If the token embedding name is set to None,
+ returns a dict mapping each valid token embedding name to a list
+ of valid pre-trained files (`pretrained_file_name`). They can be
+ plugged into `mxnet.text.embedding.TokenEmbedding.create(
+ embedding_name, pretrained_file_name)`.
+ """
+
+ text_embedding_reg = registry.get_registry(TokenEmbedding)
+
+ if embedding_name is not None:
+ if embedding_name not in text_embedding_reg:
+ raise KeyError('Cannot find `embedding_name` %s. Use '
+ '`get_embedding_and_pretrained_file_names('
+ 'embedding_name=None).keys()` to get all the '
+ 'valid embedding names.' % embedding_name)
+ return list(text_embedding_reg[
+ embedding_name].pretrained_file_name_sha1.keys())
+ else:
+ return {embedding_name: list(
+ embedding_cls.pretrained_file_name_sha1.keys())
+ for embedding_name, embedding_cls in
+ registry.get_registry(TokenEmbedding).items()}
+
+
+@TokenEmbedding.register
+class GloVe(TokenEmbedding):
+ """The GloVe word embedding.
+
+
+ GloVe is an unsupervised learning algorithm for obtaining vector
+ representations for words. Training is performed on aggregated global
+ word-word co-occurrence statistics from a corpus, and the resulting
+ representations showcase interesting linear substructures of the word vector
+ space. (Source from https://nlp.stanford.edu/projects/glove/)
+
+ Reference:
+
+ GloVe: Global Vectors for Word Representation.
+ Jeffrey Pennington, Richard Socher, and Christopher D. Manning.
+ https://nlp.stanford.edu/pubs/glove.pdf
+
+ Website:
+
+ https://nlp.stanford.edu/projects/glove/
+
+ To get the updated URLs to the externally hosted pre-trained token embedding
+ files, visit https://nlp.stanford.edu/projects/glove/
+
+
+ Parameters
+ ----------
+ pretrained_file_name : str, default 'glove.840B.300d.txt'
+ The name of the pre-trained token embedding file.
+ embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+ The root directory for storing embedding-related files.
+ init_unknown_vec : callback, default nd.zeros
+ The callback used to initialize the embedding vector for the unknown
+ token.
+
+
+ Properties
+ ----------
+ token_to_idx : dict mapping str to int
+ A dict mapping each token to its index integer.
+ idx_to_token : list of strs
+ A list of indexed tokens where the list indices and the token indices
+ are aligned.
+ unknown_token : hashable object
+ The representation for any unknown token. In other words, any
+ unknown token will be indexed as the same representation.
+ reserved_tokens : list of strs or None
+ A list of reserved tokens that will always be indexed.
+ vec_len : int
+ The length of the embedding vector for each token.
+ idx_to_vec : mxnet.ndarray.NDArray
+ For all the indexed tokens in this embedding, this NDArray maps each
+ token's index to an embedding vector. Index 0 maps to the embedding
+ vector of the unknown token, and the indices that immediately follow map
+ to those of any reserved tokens, such as a padding token.
+ """
+
+ # Maps each pre-trained token embedding archive file to its SHA-1 hash.
+ pretrained_archive_name_sha1 = \
+ {'glove.42B.300d.zip': 'f8e722b39578f776927465b71b231bae2ae8776a',
+ 'glove.6B.zip': 'b64e54f1877d2f735bdd000c1d7d771e25c7dfdc',
+ 'glove.840B.300d.zip': '8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d',
+ 'glove.twitter.27B.zip': 'dce69c404025a8312c323197347695e81fd529fc'}
+
+ # Maps each pre-trained token embedding file to its SHA-1 hash.
+ pretrained_file_name_sha1 = \
+ {'glove.42B.300d.txt': '876767977d6bd4d947c0f84d44510677bc94612a',
+ 'glove.6B.50d.txt': '21bf566a9d27f84d253e0cd4d4be9dcc07976a6d',
+ 'glove.6B.100d.txt': '16b1dbfaf35476790bd9df40c83e2dfbd05312f1',
+ 'glove.6B.200d.txt': '17d0355ddaa253e298ede39877d1be70f99d9148',
+ 'glove.6B.300d.txt': '646443dd885090927f8215ecf7a677e9f703858d',
+ 'glove.840B.300d.txt': '294b9f37fa64cce31f9ebb409c266fc379527708',
+ 'glove.twitter.27B.25d.txt':
+ '767d80889d8c8a22ae7cd25e09d0650a6ff0a502',
+ 'glove.twitter.27B.50d.txt':
+ '9585f4be97e286339bf0112d0d3aa7c15a3e864d',
+ 'glove.twitter.27B.100d.txt':
+ '1bbeab8323c72332bd46ada0fc3c99f2faaa8ca8',
+ 'glove.twitter.27B.200d.txt':
+ '7921c77a53aa5977b1d9ce3a7c4430cbd9d1207a'}
+
+ url_prefix = 'http://nlp.stanford.edu/data/'
+
+ def __init__(self, pretrained_file_name='glove.840B.300d.txt',
+ embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
+ init_unknown_vec=nd.zeros, **kwargs):
+ GloVe._check_pretrained_file_names(pretrained_file_name)
+ src_archive = {archive.split('.')[1]: archive for archive in
+ GloVe.pretrained_archive_name_sha1.keys()}
+ archive = src_archive[pretrained_file_name.split('.')[1]]
+ url = GloVe.url_prefix + archive
+
+ super(GloVe, self).__init__(**kwargs)
+
+ pretrained_file_path = GloVe._get_pretrained_file_path_from_url(
+ url, embedding_root, pretrained_file_name)
+
+ self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
+
+
+@TokenEmbedding.register
+class FastText(TokenEmbedding):
+ """The fastText word embedding.
+
+
+ FastText is an open-source, free, lightweight library that allows users to
+ learn text representations and text classifiers. It works on standard,
+ generic hardware. Models can later be reduced in size to even fit on mobile
+ devices. (Source from https://fasttext.cc/)
+
+ References:
+
+ Enriching Word Vectors with Subword Information.
+ Piotr Bojanowski, Edouard Grave, Armand Joulin, and Tomas Mikolov.
+ https://arxiv.org/abs/1607.04606
+
+ Bag of Tricks for Efficient Text Classification.
+ Armand Joulin, Edouard Grave, Piotr Bojanowski, and Tomas Mikolov.
+ https://arxiv.org/abs/1607.01759
+
+ FastText.zip: Compressing text classification models.
+ Armand Joulin, Edouard Grave, Piotr Bojanowski, Matthijs Douze, Herve Jegou,
+ and Tomas Mikolov.
+ https://arxiv.org/abs/1612.03651
+
+ Website:
+
+ https://fasttext.cc/
+
+ To get the updated URLs to the externally hosted pre-trained token embedding
+ files, visit
+ https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md
+
+
+ Parameters
+ ----------
+ pretrained_file_name : str, default 'wiki.en.vec'
+ The name of the pre-trained token embedding file.
+ embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+ The root directory for storing embedding-related files.
+ init_unknown_vec : callback, default nd.zeros
+ The callback used to initialize the embedding vector for the unknown
+ token.
+
+
+ Properties
+ ----------
+ token_to_idx : dict mapping str to int
+ A dict mapping each token to its index integer.
+ idx_to_token : list of strs
+ A list of indexed tokens where the list indices and the token indices
+ are aligned.
+ unknown_token : hashable object
+ The representation for any unknown token. In other words, any
+ unknown token will be indexed as the same representation.
+ reserved_tokens : list of strs or None
+ A list of reserved tokens that will always be indexed.
+ vec_len : int
+ The length of the embedding vector for each token.
+ idx_to_vec : mxnet.ndarray.NDArray
+ For all the indexed tokens in this embedding, this NDArray maps each
+ token's index to an embedding vector. Index 0 maps to the embedding
+ vector of the unknown token, and the indices that immediately follow map
+ to those of any reserved tokens, such as a padding token.
+ """
+
+ # Maps each pre-trained token embedding file to its SHA-1 hash.
+ pretrained_file_name_sha1 = \
+ {'wiki.en.vec': 'c1e418f144ceb332b4328d27addf508731fa87df',
+ 'wiki.simple.vec': '55267c50fbdf4e4ae0fbbda5c73830a379d68795',
+ 'wiki.zh.vec': '117ab34faa80e381641fbabf3a24bc8cfba44050'}
+ url_prefix = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/'
+
+ def __init__(self, pretrained_file_name='wiki.en.vec',
+ embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
+ init_unknown_vec=nd.zeros, **kwargs):
+ FastText._check_pretrained_file_names(pretrained_file_name)
+ url = FastText.url_prefix + pretrained_file_name
+
+ super(FastText, self).__init__(**kwargs)
+
+ pretrained_file_path = FastText._get_pretrained_file_path_from_url(
+ url, embedding_root, pretrained_file_name)
+
+ self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
+
+
+class CustomEmbedding(TokenEmbedding):
+ """User-defined token embedding.
+
+ Loads embedding vectors from a user-defined pre-trained token embedding
+ file.
+
+ Denote by '<ed>' the argument `elem_delim`, and denote by <v_ij> the j-th
+ element of the token embedding vector for <token_i>. The expected format of
+ a custom pre-trained token embedding file is::
+
+ <token_1><ed><v_11><ed><v_12><ed>...<ed><v_1k>\\n
+ <token_2><ed><v_21><ed><v_22><ed>...<ed><v_2k>\\n
+ ...
+
+ where k is the length of the embedding vector `vec_len`.
+
+
+ Parameters
+ ----------
+ pretrained_file_path : str
+ The path to the custom pre-trained token embedding file.
+ elem_delim : str, default ' '
+ The delimiter for splitting a token and every embedding vector element
+ value on the same line of the custom pre-trained token embedding file.
+ encoding : str, default 'utf8'
+ The encoding scheme for reading the custom pre-trained token embedding
+ file.
+ init_unknown_vec : callback, default nd.zeros
+ The callback used to initialize the embedding vector for the unknown
+ token.
+
+
+ Properties
+ ----------
+ token_to_idx : dict mapping str to int
+ A dict mapping each token to its index integer.
+ idx_to_token : list of strs
+ A list of indexed tokens where the list indices and the token indices
+ are aligned.
+ unknown_token : hashable object
+ The representation for any unknown token. In other words, any
+ unknown token will be indexed as the same representation.
+ reserved_tokens : list of strs or None
+ A list of reserved tokens that will always be indexed.
+ vec_len : int
+ The length of the embedding vector for each token.
+ idx_to_vec : mxnet.ndarray.NDArray
+ For all the indexed tokens in this embedding, this NDArray maps each
+ token's index to an embedding vector. Index 0 maps to the embedding
+ vector of the unknown token, and the indices that immediately follow map
+ to those of any reserved tokens, such as a padding token.
+ """
+
+ def __init__(self, pretrained_file_path, elem_delim=' ', encoding='utf8',
+ init_unknown_vec=nd.zeros, **kwargs):
+ super(CustomEmbedding, self).__init__(**kwargs)
+ self._load_embedding(pretrained_file_path, elem_delim, init_unknown_vec,
+ encoding)
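The `register`/`create` pair in `TokenEmbedding` follows a name-keyed registry pattern, and `get_embedding_and_pretrained_file_names` returns a list for one name or a dict for all names. Below is a minimal, stdlib-only sketch of that pattern, assuming registered classes are keyed by their lower-cased class name (consistent with `create` being documented as case-insensitive). `_REGISTRY`, `register`, `create`, `get_file_names`, and `ToyEmbed` are hypothetical names; this is not MXNet's `registry` module.

```python
# Hypothetical stdlib-only sketch of a name-keyed registry; it is NOT the
# actual mxnet.registry implementation.
_REGISTRY = {}


def register(cls):
    """Registers `cls` under its lower-cased class name and returns it."""
    _REGISTRY[cls.__name__.lower()] = cls
    return cls


def create(name, **kwargs):
    """Instantiates a registered class by case-insensitive name."""
    key = name.lower()
    if key not in _REGISTRY:
        raise KeyError('Cannot find embedding %s. Valid names: %s'
                       % (name, ', '.join(sorted(_REGISTRY))))
    return _REGISTRY[key](**kwargs)


def get_file_names(name=None):
    """Returns a list for one name, or a dict covering all names."""
    if name is not None:
        return list(_REGISTRY[name].pretrained_file_name_sha1)
    return {key: list(cls.pretrained_file_name_sha1)
            for key, cls in _REGISTRY.items()}


@register
class ToyEmbed(object):
    # Hypothetical file names with placeholder (not real) SHA-1 hashes.
    pretrained_file_name_sha1 = {'toy.50d.txt': '0' * 40,
                                 'toy.100d.txt': '1' * 40}

    def __init__(self, pretrained_file_name='toy.50d.txt'):
        # Mirrors _check_pretrained_file_names: reject unknown file names.
        if pretrained_file_name not in self.pretrained_file_name_sha1:
            raise KeyError('Cannot find pretrained file %s'
                           % pretrained_file_name)
        self.pretrained_file_name = pretrained_file_name


embed = create('TOYEMBED', pretrained_file_name='toy.100d.txt')
print(type(embed).__name__, embed.pretrained_file_name)  # ToyEmbed toy.100d.txt
print(get_file_names('toyembed'))  # ['toy.50d.txt', 'toy.100d.txt']
```

Keying by the lower-cased class name is one simple way to make lookups case-insensitive; whether MXNet's registry normalizes names in exactly this way is an assumption of this sketch.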
diff --git a/python/mxnet/text/glossary.py b/python/mxnet/text/glossary.py
new file mode 100644
index 0000000..941732e
--- /dev/null
+++ b/python/mxnet/text/glossary.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Index text tokens and load their embeddings."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+from .. import ndarray as nd
+from .embedding import TokenEmbedding
+
+
+class Glossary(TokenEmbedding):
+ """Indexing and embedding for text tokens in a glossary.
+
+
+ For each indexed token in a glossary, an embedding vector will be associated
+ with it. Such embedding vectors can be loaded from externally hosted or
+ custom pre-trained token embedding files, such as via instances of
+ :class:`~mxnet.text.embedding.TokenEmbedding`.
+
+
+ Parameters
+ ----------
+ counter : collections.Counter or None
+ Counts text token frequencies in the text data. Its keys will be indexed
+ according to frequency thresholds such as `most_freq_count` and
+ `min_freq`. Keys of `counter`, `unknown_token`, and values of
+ `reserved_tokens` must be of the same hashable type. Examples: str, int,
+ and tuple.
+ token_embeddings : instance or list of :class:`~TokenEmbedding`
+ One or multiple pre-trained token embeddings to load. If it is a list of
+ multiple embeddings, these embedding vectors will be concatenated for
+ each token.
+ most_freq_count : None or int, default None
+ The maximum possible number of the most frequent tokens in the keys of
+ `counter` that can be indexed. Note that this argument does not count
+ any token from `reserved_tokens`. If this argument is None or larger
+ than its largest possible value restricted by `counter` and
+ `reserved_tokens`, this argument has no effect.
+ min_freq : int, default 1
+ The minimum frequency required for a token in the keys of `counter` to
+ be indexed.
+ unknown_token : hashable object, default '<unk>'
+ The representation for any unknown token. In other words, any unknown
+ token will be indexed as the same representation. Keys of `counter`,
+ `unknown_token`, and values of `reserved_tokens` must be of the same
+ hashable type. Examples: str, int, and tuple.
+ reserved_tokens : list of hashable objects or None, default None
+ A list of reserved tokens that will always be indexed, such as special
+ symbols representing padding, beginning of sentence, and end of
+ sentence. It cannot contain `unknown_token`, or duplicate reserved
+ tokens. Keys of `counter`, `unknown_token`, and values of
+ `reserved_tokens` must be of the same hashable type. Examples: str, int,
+ and tuple.
+
+
+ Properties
+ ----------
+ token_to_idx : dict mapping str to int
+ A dict mapping each token to its index integer.
+ idx_to_token : list of strs
+ A list of indexed tokens where the list indices and the token indices
+ are aligned.
+ unknown_token : hashable object
+ The representation for any unknown token. In other words, any
+ unknown token will be indexed as the same representation.
+ reserved_tokens : list of strs or None
+ A list of reserved tokens that will always be indexed.
+ vec_len : int
+ The length of the embedding vector for each token.
+ idx_to_vec : mxnet.ndarray.NDArray
+ For all the indexed tokens in this embedding, this NDArray maps each
+ token's index to an embedding vector. Index 0 maps to the embedding
+ vector of the unknown token, and the indices that immediately follow map
+ to those of any reserved tokens, such as a padding token.
+ """
+ def __init__(self, counter, token_embeddings, most_freq_count=None,
+ min_freq=1, unknown_token='<unk>', reserved_tokens=None):
+
+ if not isinstance(token_embeddings, list):
+ token_embeddings = [token_embeddings]
+
+ # Sanity checks.
+ for embed in token_embeddings:
+ assert isinstance(embed, TokenEmbedding), \
+ 'The parameter `token_embeddings` must be an instance or a ' \
+ 'list of instances of `mxnet.text.embedding.TokenEmbedding` ' \
+ 'whose embedding vectors will be loaded or ' \
+ 'concatenated-then-loaded to map to the indexed tokens.'
+
+ # Index tokens from keys of `counter` and reserved tokens.
+ super(Glossary, self).__init__(counter=counter,
+ most_freq_count=most_freq_count,
+ min_freq=min_freq,
+ unknown_token=unknown_token,
+ reserved_tokens=reserved_tokens)
+
+ # Set _idx_to_vec so that indices of tokens from keys of `counter` are
+ # associated with token embedding vectors from `token_embeddings`.
+ self._set_idx_to_vec_by_embeds(token_embeddings)
+
+ def _set_idx_to_vec_by_embeds(self, token_embeddings):
+ """Sets the mapping between token indices and token embedding vectors.
+
+
+ Parameters
+ ----------
+ token_embeddings : an instance or a list of instances of
+ :class:`~mxnet.text.embedding.TokenEmbedding`
+ One or multiple pre-trained token embeddings to load. If it is a
+ list of multiple embeddings, these embedding vectors will be
+ concatenated for each token.
+ """
+
+ self._vec_len = sum(embed.vec_len for embed in token_embeddings)
+ self._idx_to_vec = nd.zeros(shape=(len(self), self.vec_len))
+
+ col_start = 0
+ # Concatenate all the embedding vectors in token_embeddings.
+ for embed in token_embeddings:
+ col_end = col_start + embed.vec_len
+ # Concatenate vectors of the unknown token.
+ self._idx_to_vec[0, col_start:col_end] = embed.idx_to_vec[0]
+ self._idx_to_vec[1:, col_start:col_end] = embed.get_vecs_by_tokens(
+ self.idx_to_token[1:])
+ col_start = col_end
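`_set_idx_to_vec_by_embeds` concatenates the vectors of several embeddings column-wise, and fills row 0 (the unknown token) with each embedding's own unknown-token vector. The stdlib-only sketch below reproduces that layout with plain lists. Modeling each embedding as a dict with the hypothetical keys `vec_len`, `unknown_vec`, and `vecs`, and falling back to the unknown-token vector for tokens an embedding does not contain, are assumptions standing in for `TokenEmbedding.idx_to_vec[0]` and `get_vecs_by_tokens`.

```python
# Stdlib-only sketch of column-wise embedding concatenation. Each embedding
# is a hypothetical dict, not an mxnet.text.embedding.TokenEmbedding.


def concat_embeddings(idx_to_token, embeddings):
    """Builds one row per indexed token; row 0 is the unknown token."""
    vec_len = sum(e['vec_len'] for e in embeddings)
    rows = [[0.0] * vec_len for _ in idx_to_token]
    col_start = 0
    for e in embeddings:
        col_end = col_start + e['vec_len']
        # Row 0 takes this embedding's own unknown-token vector.
        rows[0][col_start:col_end] = e['unknown_vec']
        # Other rows look tokens up, falling back to the unknown vector
        # (assumed lookup behavior for tokens missing from an embedding).
        for i, token in enumerate(idx_to_token[1:], start=1):
            rows[i][col_start:col_end] = e['vecs'].get(token, e['unknown_vec'])
        col_start = col_end
    return rows


emb_a = {'vec_len': 2, 'unknown_vec': [0.0, 0.0], 'vecs': {'hello': [1.0, 2.0]}}
emb_b = {'vec_len': 1, 'unknown_vec': [-1.0],
         'vecs': {'hello': [3.0], 'world': [4.0]}}
rows = concat_embeddings(['<unk>', 'hello', 'world'], [emb_a, emb_b])
print(rows)  # [[0.0, 0.0, -1.0], [1.0, 2.0, 3.0], [0.0, 0.0, 4.0]]
```

Each token's final vector length is the sum of the individual `vec_len` values, which is why the glossary's `vec_len` is computed as that sum before any rows are filled.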
diff --git a/python/mxnet/text/indexer.py b/python/mxnet/text/indexer.py
new file mode 100644
index 0000000..bed2794
--- /dev/null
+++ b/python/mxnet/text/indexer.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=consider-iterating-dictionary
+
+"""Text token indexer."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+from collections import Counter
+
+from . import constants as C
+
+
+class TokenIndexer(object):
+ """Indexing for text tokens.
+
+
+ Builds indices for the unknown token, reserved tokens, and input counter
+ keys. Indexed tokens can be used by instances of
+ :class:`~mxnet.text.embedding.TokenEmbedding`, such as instances of
+ :class:`~mxnet.text.glossary.Glossary`.
+
+
+ Parameters
+ ----------
+ counter : collections.Counter or None, default None
+ Counts text token frequencies in the text data. Its keys will be indexed
+ according to frequency thresholds such as `most_freq_count` and
+ `min_freq`. Keys of `counter`, `unknown_token`, and values of
+ `reserved_tokens` must be of the same hashable type. Examples: str, int,
+ and tuple.
+ most_freq_count : None or int, default None
+ The maximum possible number of the most frequent tokens in the keys of
+ `counter` that can be indexed. Note that this argument does not count
+ any token from `reserved_tokens`. If several keys of `counter` have the
+ same frequency and indexing all of them would exceed this limit, they
+ are indexed in ascending sort order of the keys until the limit is
+ reached. If this argument is None or larger than its largest possible
+ value restricted by `counter` and `reserved_tokens`, this argument has
+ no effect.
+ min_freq : int, default 1
+ The minimum frequency required for a token in the keys of `counter` to
+ be indexed.
+ unknown_token : hashable object, default '<unk>'
+ The representation for any unknown token. In other words, any unknown
+ token will be indexed as the same representation. Keys of `counter`,
+ `unknown_token`, and values of `reserved_tokens` must be of the same
+ hashable type. Examples: str, int, and tuple.
+ reserved_tokens : list of hashable objects or None, default None
+ A list of reserved tokens that will always be indexed, such as special
+ symbols representing padding, beginning of sentence, and end of
+ sentence. It cannot contain `unknown_token`, or duplicate reserved
+ tokens. Keys of `counter`, `unknown_token`, and values of
+ `reserved_tokens` must be of the same hashable type. Examples: str, int,
+ and tuple.
+
+
+ Properties
+ ----------
+ token_to_idx : dict mapping str to int
+ A dict mapping each token to its index integer.
+ idx_to_token : list of strs
+ A list of indexed tokens where the list indices and the token indices
+ are aligned.
+ unknown_token : hashable object
+ The representation for any unknown token. In other words, any
+ unknown token will be indexed as the same representation.
+ reserved_tokens : list of strs or None
+ A list of reserved tokens that will always be indexed.
+ """
+
+ def __init__(self, counter=None, most_freq_count=None, min_freq=1,
+ unknown_token='<unk>', reserved_tokens=None):
+
+ # Sanity checks.
+ assert min_freq > 0, '`min_freq` must be set to a positive value.'
+
+ if reserved_tokens is not None:
+ reserved_token_set = set(reserved_tokens)
+ assert unknown_token not in reserved_token_set, \
+ '`reserved_tokens` cannot contain `unknown_token`.'
+ assert len(reserved_token_set) == len(reserved_tokens), \
+ '`reserved_tokens` cannot contain duplicate reserved tokens.'
+
+ self._index_unknown_and_reserved_tokens(unknown_token, reserved_tokens)
+
+ if counter is not None:
+ self._index_counter_keys(counter, unknown_token, reserved_tokens,
+ most_freq_count, min_freq)
+
+ def _index_unknown_and_reserved_tokens(self, unknown_token,
+ reserved_tokens):
+ """Indexes unknown and reserved tokens."""
+
+ self._unknown_token = unknown_token
+ # The unknown token is always indexed first, so
+ # constants.UNKNOWN_IDX must be 0.
+ self._idx_to_token = [unknown_token]
+
+ if reserved_tokens is None:
+ self._reserved_tokens = None
+ else:
+ self._reserved_tokens = reserved_tokens[:]
+ self._idx_to_token.extend(reserved_tokens)
+
+ self._token_to_idx = {token: idx for idx, token in
+ enumerate(self._idx_to_token)}
+
+ def _index_counter_keys(self, counter, unknown_token, reserved_tokens,
+ most_freq_count, min_freq):
+ """Indexes keys of `counter`.
+
+
+ Indexes keys of `counter` according to frequency thresholds such as
+ `most_freq_count` and `min_freq`.
+ """
+
+ assert isinstance(counter, Counter), \
+ '`counter` must be an instance of collections.Counter.'
+
+ unknown_and_reserved_tokens = set(reserved_tokens) \
+ if reserved_tokens is not None else set()
+ unknown_and_reserved_tokens.add(unknown_token)
+
+ token_freqs = sorted(counter.items(), key=lambda x: x[0])
+ token_freqs.sort(key=lambda x: x[1], reverse=True)
+
+ token_cap = len(unknown_and_reserved_tokens) + (
+ len(counter) if most_freq_count is None else most_freq_count)
+
+ for token, freq in token_freqs:
+ if freq < min_freq or len(self._idx_to_token) == token_cap:
+ break
+ if token not in unknown_and_reserved_tokens:
+ self._idx_to_token.append(token)
+ self._token_to_idx[token] = len(self._idx_to_token) - 1
+
+ def __len__(self):
+ return len(self.idx_to_token)
+
+ @property
+ def token_to_idx(self):
+ return self._token_to_idx
+
+ @property
+ def idx_to_token(self):
+ return self._idx_to_token
+
+ @property
+ def unknown_token(self):
+ return self._unknown_token
+
+ @property
+ def reserved_tokens(self):
+ return self._reserved_tokens
+
+ def to_indices(self, tokens):
+ """Converts tokens to indices according to the text indexer.
+
+
+ Parameters
+ ----------
+ tokens : str or list of strs
+ A source token or tokens to be converted.
+
+
+ Returns
+ -------
+ int or list of ints
+ A token index or a list of token indices according to the text
+ indexer. Tokens that are not indexed map to the index of the unknown
+ token.
+ """
+
+ to_reduce = False
+ if not isinstance(tokens, list):
+ tokens = [tokens]
+ to_reduce = True
+
+ indices = [self.token_to_idx[token] if token in self.token_to_idx
+ else C.UNKNOWN_IDX for token in tokens]
+
+ return indices[0] if to_reduce else indices
+
+ def to_tokens(self, indices):
+ """Converts token indices to tokens according to the text indexer.
+
+
+ Parameters
+ ----------
+ indices : int or list of ints
+ A source token index or token indices to be converted.
+
+
+ Returns
+ -------
+ str or list of strs
+ A token or a list of tokens according to the text indexer.
+ """
+
+ to_reduce = False
+ if not isinstance(indices, list):
+ indices = [indices]
+ to_reduce = True
+
+ max_idx = len(self.idx_to_token) - 1
+
+ tokens = []
+ for idx in indices:
+ if not isinstance(idx, int) or idx < 0 or idx > max_idx:
+ raise ValueError('Token index %s in the provided `indices` is '
+ 'invalid.' % str(idx))
+ else:
+ tokens.append(self.idx_to_token[idx])
+
+ return tokens[0] if to_reduce else tokens
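The indexing rules in `TokenIndexer` (unknown token at index 0, reserved tokens next, then counter keys by descending frequency with ties broken by ascending key order, subject to `most_freq_count` and `min_freq`) can be reproduced with a small stdlib-only sketch. `build_index` and `to_indices` are hypothetical helpers that mirror `_index_unknown_and_reserved_tokens`, `_index_counter_keys`, and `TokenIndexer.to_indices`; they are not part of `mxnet.text`.

```python
from collections import Counter

UNKNOWN_IDX = 0  # The unknown token is always indexed first.


def build_index(counter, most_freq_count=None, min_freq=1,
                unknown_token='<unk>', reserved_tokens=None):
    """Returns (idx_to_token, token_to_idx) following TokenIndexer's rules."""
    idx_to_token = [unknown_token] + list(reserved_tokens or [])
    special = set(idx_to_token)
    # Sort by key first; the stable sort by frequency then breaks ties in
    # ascending key order.
    token_freqs = sorted(counter.items(), key=lambda x: x[0])
    token_freqs.sort(key=lambda x: x[1], reverse=True)
    token_cap = len(special) + (
        len(counter) if most_freq_count is None else most_freq_count)
    for token, freq in token_freqs:
        if freq < min_freq or len(idx_to_token) == token_cap:
            break
        if token not in special:
            idx_to_token.append(token)
    token_to_idx = {token: idx for idx, token in enumerate(idx_to_token)}
    return idx_to_token, token_to_idx


def to_indices(token_to_idx, tokens):
    """Maps tokens to indices; unindexed tokens map to UNKNOWN_IDX."""
    return [token_to_idx.get(token, UNKNOWN_IDX) for token in tokens]


counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
idx_to_token, token_to_idx = build_index(counter)
print(idx_to_token)  # ['<unk>', 'c', 'b', 'a', 'some_word$']
print(to_indices(token_to_idx, ['a', 'non-exist', 'a', 'b']))  # [3, 0, 3, 2]
print(build_index(counter, most_freq_count=2, reserved_tokens=['<pad>'])[0])
# ['<unk>', '<pad>', 'c', 'b']
```

Sorting twice relies on Python's sort being stable even with `reverse=True`, which is what makes the tie-breaking deterministic.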
diff --git a/python/mxnet/text/utils.py b/python/mxnet/text/utils.py
new file mode 100644
index 0000000..91e1b62
--- /dev/null
+++ b/python/mxnet/text/utils.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Provide utilities for text data processing."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+from collections import Counter
+import re
+
+
+def count_tokens_from_str(source_str, token_delim=' ', seq_delim='\n',
+ to_lower=False, counter_to_update=None):
+ """Counts tokens in the specified string.
+
+ For token_delim='<td>' and seq_delim='<sd>', a specified string of two
+ sequences of tokens may look like::
+
+ <td>token1<td>token2<td>token3<td><sd><td>token4<td>token5<td><sd>
+
+
+ Parameters
+ ----------
+ source_str : str
+ A source string of tokens.
+ token_delim : str, default ' '
+ A token delimiter.
+ seq_delim : str, default '\\\\n'
+ A sequence delimiter.
+ to_lower : bool, default False
+ Whether to convert `source_str` to lower case.
+ counter_to_update : collections.Counter or None, default None
+ The collections.Counter instance to be updated with the token counts
+ of `source_str`. If None, return a new collections.Counter instance
+ counting tokens from `source_str`.
+
+
+ Returns
+ -------
+ collections.Counter
+ The `counter_to_update` collections.Counter instance after being updated
+ with the token counts of `source_str`. If `counter_to_update` is None,
+ return a new collections.Counter instance counting tokens from
+ `source_str`.
+
+
+ Examples
+ --------
+ >>> source_str = ' Life is great ! \\n life is good . \\n'
+ >>> count_tokens_from_str(source_str, ' ', '\\n', True)
+ Counter({'life': 2, 'is': 2, 'great': 1, '!': 1, 'good': 1, '.': 1})
+ """
+
+ # Escape the delimiters so that they are matched literally by re.split.
+ tokens = filter(None, re.split(re.escape(token_delim) + '|' +
+ re.escape(seq_delim), source_str))
+ if to_lower:
+ tokens = [t.lower() for t in tokens]
+
+ if counter_to_update is None:
+ return Counter(tokens)
+ else:
+ counter_to_update.update(tokens)
+ return counter_to_update
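`count_tokens_from_str` splits `source_str` with `re.split`, whose pattern is a regular expression, so a delimiter containing metacharacters such as '.' or '|' must be escaped to be matched literally. The stdlib-only sketch below, a hypothetical `count_tokens` helper rather than the MXNet function, escapes both delimiters with `re.escape`, assuming they are meant as literal strings.

```python
import re
from collections import Counter


def count_tokens(source_str, token_delim=' ', seq_delim='\n',
                 to_lower=False, counter_to_update=None):
    """Counts tokens, treating both delimiters as literal strings."""
    pattern = re.escape(token_delim) + '|' + re.escape(seq_delim)
    # Drop the empty strings produced by leading/trailing/adjacent delimiters.
    tokens = [t for t in re.split(pattern, source_str) if t]
    if to_lower:
        tokens = [t.lower() for t in tokens]
    if counter_to_update is None:
        return Counter(tokens)
    counter_to_update.update(tokens)
    return counter_to_update


print(count_tokens(' Life is great ! \n life is good . \n', to_lower=True))
# Counter({'life': 2, 'is': 2, 'great': 1, '!': 1, 'good': 1, '.': 1})

# Unescaped, '.' matches every character, so no tokens survive.
print([t for t in re.split('.|\n', 'a.b.a') if t])  # []
print(count_tokens('a.b.a', token_delim='.'))  # Counter({'a': 2, 'b': 1})
```

Multi-character delimiters such as 'IS' and 'LIFE', as used in the unit tests, behave the same whether or not they are escaped, because they contain no metacharacters.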
diff --git a/tests/python/unittest/test_text.py b/tests/python/unittest/test_text.py
new file mode 100644
index 0000000..9674304
--- /dev/null
+++ b/tests/python/unittest/test_text.py
@@ -0,0 +1,743 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# 'License'); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+from collections import Counter
+import unittest
+
+from common import assertRaises
+from mxnet import ndarray as nd
+from mxnet.test_utils import *
+from mxnet.text import utils
+from mxnet.text.glossary import Glossary
+from mxnet.text.indexer import TokenIndexer
+from mxnet.text.embedding import TokenEmbedding, CustomEmbedding
+
+
+def _get_test_str_of_tokens(token_delim, seq_delim):
+    seq1 = token_delim + token_delim.join(['Life', 'is', 'great', '!']) \
+        + token_delim + seq_delim
+    seq2 = token_delim + token_delim.join(['life', 'is', 'good', '.']) \
+        + token_delim + seq_delim
+    seq3 = token_delim + token_delim.join(['life', "isn't", 'bad', '.']) \
+        + token_delim + seq_delim
+    seqs = seq1 + seq2 + seq3
+    return seqs
+
+
+def _test_count_tokens_from_str_with_delims(token_delim, seq_delim):
+    source_str = _get_test_str_of_tokens(token_delim, seq_delim)
+
+    cnt1 = utils.count_tokens_from_str(source_str, token_delim, seq_delim,
+                                       to_lower=False)
+    assert cnt1 == Counter(
+        {'is': 2, 'life': 2, '.': 2, 'Life': 1, 'great': 1, '!': 1, 'good': 1,
+         "isn't": 1, 'bad': 1})
+
+    cnt2 = utils.count_tokens_from_str(source_str, token_delim, seq_delim,
+                                       to_lower=True)
+    assert cnt2 == Counter(
+        {'life': 3, 'is': 2, '.': 2, 'great': 1, '!': 1, 'good': 1,
+         "isn't": 1, 'bad': 1})
+
+    counter_to_update = Counter({'life': 2})
+
+    cnt3 = utils.count_tokens_from_str(
+        source_str, token_delim, seq_delim, to_lower=False,
+        counter_to_update=counter_to_update.copy())
+    assert cnt3 == Counter(
+        {'is': 2, 'life': 4, '.': 2, 'Life': 1, 'great': 1, '!': 1, 'good': 1,
+         "isn't": 1, 'bad': 1})
+
+    cnt4 = utils.count_tokens_from_str(
+        source_str, token_delim, seq_delim, to_lower=True,
+        counter_to_update=counter_to_update.copy())
+    assert cnt4 == Counter(
+        {'life': 5, 'is': 2, '.': 2, 'great': 1, '!': 1, 'good': 1,
+         "isn't": 1, 'bad': 1})
+
+
+def test_count_tokens_from_str():
+    _test_count_tokens_from_str_with_delims(' ', '\n')
+    _test_count_tokens_from_str_with_delims('IS', 'LIFE')
+
+
+def test_tokens_to_indices():
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    indexer = TokenIndexer(counter, most_freq_count=None, min_freq=1,
+                           unknown_token='<unk>', reserved_tokens=None)
+
+    i1 = indexer.to_indices('c')
+    assert i1 == 1
+
+    i2 = indexer.to_indices(['c'])
+    assert i2 == [1]
+
+    i3 = indexer.to_indices(['<unk>', 'non-exist'])
+    assert i3 == [0, 0]
+
+    i4 = indexer.to_indices(['a', 'non-exist', 'a', 'b'])
+    assert i4 == [3, 0, 3, 2]
+
+
+def test_indices_to_tokens():
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    indexer = TokenIndexer(counter, most_freq_count=None, min_freq=1,
+                           unknown_token='<unknown>', reserved_tokens=None)
+
+    i1 = indexer.to_tokens(1)
+    assert i1 == 'c'
+
+    i2 = indexer.to_tokens([1])
+    assert i2 == ['c']
+
+    i3 = indexer.to_tokens([0, 0])
+    assert i3 == ['<unknown>', '<unknown>']
+
+    i4 = indexer.to_tokens([3, 0, 3, 2])
+    assert i4 == ['a', '<unknown>', 'a', 'b']
+
+    assertRaises(ValueError, indexer.to_tokens, 100)
+
+
+def test_glove():
+    glove_6b_50d = TokenEmbedding.create(
+        'glove', pretrained_file_name='glove.6B.50d.txt')
+
+    assert len(glove_6b_50d) == 400001
+    assert glove_6b_50d.vec_len == 50
+    assert glove_6b_50d.token_to_idx['hi'] == 11084
+    assert glove_6b_50d.idx_to_token[11084] == 'hi'
+
+    first_vec_sum = glove_6b_50d.idx_to_vec[0].sum().asnumpy()[0]
+    assert_almost_equal(first_vec_sum, 0)
+
+    unk_vec_sum = glove_6b_50d.get_vecs_by_tokens(
+        '<un...@unk>').sum().asnumpy()[0]
+    assert_almost_equal(unk_vec_sum, 0)
+
+    unk_vecs_sum = glove_6b_50d.get_vecs_by_tokens(
+        ['<un...@unk>', '<un...@unk>']).sum().asnumpy()[0]
+    assert_almost_equal(unk_vecs_sum, 0)
+
+
+def test_fasttext():
+    fasttext_simple = TokenEmbedding.create(
+        'fasttext', pretrained_file_name='wiki.simple.vec',
+        init_unknown_vec=nd.ones)
+
+    assert len(fasttext_simple) == 111052
+    assert fasttext_simple.vec_len == 300
+    assert fasttext_simple.token_to_idx['hi'] == 3241
+    assert fasttext_simple.idx_to_token[3241] == 'hi'
+
+    first_vec_sum = fasttext_simple.idx_to_vec[0].sum().asnumpy()[0]
+    assert_almost_equal(first_vec_sum, fasttext_simple.vec_len)
+
+    unk_vec_sum = fasttext_simple.get_vecs_by_tokens(
+        '<un...@unk>').sum().asnumpy()[0]
+    assert_almost_equal(unk_vec_sum, fasttext_simple.vec_len)
+
+    unk_vecs_sum = fasttext_simple.get_vecs_by_tokens(
+        ['<un...@unk>', '<un...@unk>']).sum().asnumpy()[0]
+    assert_almost_equal(unk_vecs_sum, fasttext_simple.vec_len * 2)
+
+
+def _mk_my_pretrain_file(path, token_delim, pretrain_file):
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
+    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
+    seqs = seq1 + seq2
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_pretrain_file2(path, token_delim, pretrain_file):
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.01', '0.02', '0.03', '0.04',
+                             '0.05']) + '\n'
+    seq2 = token_delim.join(['c', '0.06', '0.07', '0.08', '0.09', '0.1']) + '\n'
+    seqs = seq1 + seq2
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_pretrain_file3(path, token_delim, pretrain_file):
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
+    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
+    seq3 = token_delim.join(['<unk1>', '1.1', '1.2', '1.3', '1.4',
+                             '1.5']) + '\n'
+    seqs = seq1 + seq2 + seq3
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_pretrain_file4(path, token_delim, pretrain_file):
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.01', '0.02', '0.03', '0.04',
+                             '0.05']) + '\n'
+    seq2 = token_delim.join(['c', '0.06', '0.07', '0.08', '0.09',
+                             '0.1']) + '\n'
+    seq3 = token_delim.join(['<unk2>', '0.11', '0.12', '0.13', '0.14',
+                             '0.15']) + '\n'
+    seqs = seq1 + seq2 + seq3
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_invalid_pretrain_file(path, token_delim, pretrain_file):
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
+    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
+    seq3 = token_delim.join(['c']) + '\n'
+    seqs = seq1 + seq2 + seq3
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_invalid_pretrain_file2(path, token_delim, pretrain_file):
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
+    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
+    seq3 = token_delim.join(['c', '0.6', '0.7', '0.8']) + '\n'
+    seqs = seq1 + seq2 + seq3
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def test_custom_embed():
+    embed_root = '~/.mxnet/embeddings/'
+    embed_name = 'my_embed'
+    elem_delim = '\t'
+    pretrain_file = 'my_pretrain_file.txt'
+
+    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim,
+                         pretrain_file)
+
+    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)
+
+    my_embed = CustomEmbedding(pretrain_file_path, elem_delim)
+
+    assert len(my_embed) == 3
+    assert my_embed.vec_len == 5
+    assert my_embed.token_to_idx['a'] == 1
+    assert my_embed.idx_to_token[1] == 'a'
+
+    first_vec = my_embed.idx_to_vec[0]
+    assert_almost_equal(first_vec.asnumpy(), np.array([0, 0, 0, 0, 0]))
+
+    unk_vec = my_embed.get_vecs_by_tokens('A')
+    assert_almost_equal(unk_vec.asnumpy(), np.array([0, 0, 0, 0, 0]))
+
+    a_vec = my_embed.get_vecs_by_tokens('A', lower_case_backup=True)
+    assert_almost_equal(a_vec.asnumpy(), np.array([0.1, 0.2, 0.3, 0.4, 0.5]))
+
+    unk_vecs = my_embed.get_vecs_by_tokens(['<un...@unk>', '<un...@unk>'])
+    assert_almost_equal(unk_vecs.asnumpy(),
+                        np.array([[0, 0, 0, 0, 0],
+                                  [0, 0, 0, 0, 0]]))
+
+    # Test loaded unknown vectors.
+    pretrain_file2 = 'my_pretrain_file2.txt'
+    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim,
+                          pretrain_file2)
+    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file2)
+    my_embed2 = CustomEmbedding(pretrain_file_path, elem_delim,
+                                init_unknown_vec=nd.ones,
+                                unknown_token='<unk>')
+    unk_vec2 = my_embed2.get_vecs_by_tokens('<unk>')
+    assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
+    unk_vec2 = my_embed2.get_vecs_by_tokens('<un...@unk>')
+    assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
+
+    my_embed3 = CustomEmbedding(pretrain_file_path, elem_delim,
+                                init_unknown_vec=nd.ones,
+                                unknown_token='<unk1>')
+    unk_vec3 = my_embed3.get_vecs_by_tokens('<unk1>')
+    assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 1.5]))
+    unk_vec3 = my_embed3.get_vecs_by_tokens('<un...@unk>')
+    assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 1.5]))
+
+    # Test error handling.
+    invalid_pretrain_file = 'invalid_pretrain_file.txt'
+    _mk_my_invalid_pretrain_file(os.path.join(embed_root, embed_name),
+                                 elem_delim, invalid_pretrain_file)
+    pretrain_file_path = os.path.join(embed_root, embed_name,
+                                      invalid_pretrain_file)
+    assertRaises(AssertionError, CustomEmbedding, pretrain_file_path,
+                 elem_delim)
+
+    invalid_pretrain_file2 = 'invalid_pretrain_file2.txt'
+    _mk_my_invalid_pretrain_file2(os.path.join(embed_root, embed_name),
+                                  elem_delim, invalid_pretrain_file2)
+    pretrain_file_path = os.path.join(embed_root, embed_name,
+                                      invalid_pretrain_file2)
+    assertRaises(AssertionError, CustomEmbedding, pretrain_file_path,
+                 elem_delim)
+
+
+def test_token_indexer():
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    g1 = TokenIndexer(counter, most_freq_count=None, min_freq=1,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g1) == 5
+    assert g1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
+                               'some_word$': 4}
+    assert g1.idx_to_token[1] == 'c'
+    assert g1.unknown_token == '<unk>'
+    assert g1.reserved_tokens is None
+
+    g2 = TokenIndexer(counter, most_freq_count=None, min_freq=2,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g2) == 3
+    assert g2.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
+    assert g2.idx_to_token[1] == 'c'
+    assert g2.unknown_token == '<unk>'
+    assert g2.reserved_tokens is None
+
+    g3 = TokenIndexer(counter, most_freq_count=None, min_freq=100,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g3) == 1
+    assert g3.token_to_idx == {'<unk>': 0}
+    assert g3.idx_to_token[0] == '<unk>'
+    assert g3.unknown_token == '<unk>'
+    assert g3.reserved_tokens is None
+
+    g4 = TokenIndexer(counter, most_freq_count=2, min_freq=1,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g4) == 3
+    assert g4.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
+    assert g4.idx_to_token[1] == 'c'
+    assert g4.unknown_token == '<unk>'
+    assert g4.reserved_tokens is None
+
+    g5 = TokenIndexer(counter, most_freq_count=3, min_freq=1,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g5) == 4
+    assert g5.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3}
+    assert g5.idx_to_token[1] == 'c'
+    assert g5.unknown_token == '<unk>'
+    assert g5.reserved_tokens is None
+
+    g6 = TokenIndexer(counter, most_freq_count=100, min_freq=1,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g6) == 5
+    assert g6.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
+                               'some_word$': 4}
+    assert g6.idx_to_token[1] == 'c'
+    assert g6.unknown_token == '<unk>'
+    assert g6.reserved_tokens is None
+
+    g7 = TokenIndexer(counter, most_freq_count=1, min_freq=2,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g7) == 2
+    assert g7.token_to_idx == {'<unk>': 0, 'c': 1}
+    assert g7.idx_to_token[1] == 'c'
+    assert g7.unknown_token == '<unk>'
+    assert g7.reserved_tokens is None
+
+    assertRaises(AssertionError, TokenIndexer, counter, most_freq_count=None,
+                 min_freq=0, unknown_token='<unknown>',
+                 reserved_tokens=['b'])
+
+    assertRaises(AssertionError, TokenIndexer, counter, most_freq_count=None,
+                 min_freq=1, unknown_token='<unknown>',
+                 reserved_tokens=['b', 'b'])
+
+    assertRaises(AssertionError, TokenIndexer, counter, most_freq_count=None,
+                 min_freq=1, unknown_token='<unknown>',
+                 reserved_tokens=['b', '<unknown>'])
+
+    g8 = TokenIndexer(counter, most_freq_count=None, min_freq=1,
+                      unknown_token='<unknown>', reserved_tokens=['b'])
+    assert len(g8) == 5
+    assert g8.token_to_idx == {'<unknown>': 0, 'b': 1, 'c': 2, 'a': 3,
+                               'some_word$': 4}
+    assert g8.idx_to_token[1] == 'b'
+    assert g8.unknown_token == '<unknown>'
+    assert g8.reserved_tokens == ['b']
+
+    g9 = TokenIndexer(counter, most_freq_count=None, min_freq=2,
+                      unknown_token='<unk>', reserved_tokens=['b', 'a'])
+    assert len(g9) == 4
+    assert g9.token_to_idx == {'<unk>': 0, 'b': 1, 'a': 2, 'c': 3}
+    assert g9.idx_to_token[1] == 'b'
+    assert g9.unknown_token == '<unk>'
+    assert g9.reserved_tokens == ['b', 'a']
+
+    g10 = TokenIndexer(counter, most_freq_count=None, min_freq=100,
+                       unknown_token='<unk>', reserved_tokens=['b', 'c'])
+    assert len(g10) == 3
+    assert g10.token_to_idx == {'<unk>': 0, 'b': 1, 'c': 2}
+    assert g10.idx_to_token[1] == 'b'
+    assert g10.unknown_token == '<unk>'
+    assert g10.reserved_tokens == ['b', 'c']
+
+    g11 = TokenIndexer(counter, most_freq_count=1, min_freq=2,
+                       unknown_token='<unk>', reserved_tokens=['<pad>', 'b'])
+    assert len(g11) == 4
+    assert g11.token_to_idx == {'<unk>': 0, '<pad>': 1, 'b': 2, 'c': 3}
+    assert g11.idx_to_token[1] == '<pad>'
+    assert g11.unknown_token == '<unk>'
+    assert g11.reserved_tokens == ['<pad>', 'b']
+
+    g12 = TokenIndexer(counter, most_freq_count=None, min_freq=2,
+                       unknown_token='b', reserved_tokens=['<pad>'])
+    assert len(g12) == 3
+    assert g12.token_to_idx == {'b': 0, '<pad>': 1, 'c': 2}
+    assert g12.idx_to_token[1] == '<pad>'
+    assert g12.unknown_token == 'b'
+    assert g12.reserved_tokens == ['<pad>']
+
+    g13 = TokenIndexer(counter, most_freq_count=None, min_freq=2,
+                       unknown_token='a', reserved_tokens=['<pad>'])
+    assert len(g13) == 4
+    assert g13.token_to_idx == {'a': 0, '<pad>': 1, 'c': 2, 'b': 3}
+    assert g13.idx_to_token[1] == '<pad>'
+    assert g13.unknown_token == 'a'
+    assert g13.reserved_tokens == ['<pad>']
+
+    counter_tuple = Counter([('a', 'a'), ('b', 'b'), ('b', 'b'),
+                             ('c', 'c'), ('c', 'c'), ('c', 'c'),
+                             ('some_word$', 'some_word$')])
+
+    g14 = TokenIndexer(counter_tuple, most_freq_count=None, min_freq=1,
+                       unknown_token=('<unk>', '<unk>'), reserved_tokens=None)
+    assert len(g14) == 5
+    assert g14.token_to_idx == {('<unk>', '<unk>'): 0, ('c', 'c'): 1,
+                                ('b', 'b'): 2, ('a', 'a'): 3,
+                                ('some_word$', 'some_word$'): 4}
+    assert g14.idx_to_token[1] == ('c', 'c')
+    assert g14.unknown_token == ('<unk>', '<unk>')
+    assert g14.reserved_tokens is None
+
+
+def test_glossary_with_one_embed():
+    embed_root = '~/.mxnet/embeddings/'
+    embed_name = 'my_embed'
+    elem_delim = '\t'
+    pretrain_file = 'my_pretrain_file1.txt'
+
+    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim,
+                         pretrain_file)
+
+    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)
+
+    my_embed = CustomEmbedding(pretrain_file_path, elem_delim,
+                               init_unknown_vec=nd.ones)
+
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    g1 = Glossary(counter, my_embed, most_freq_count=None, min_freq=1,
+                  unknown_token='<unk>', reserved_tokens=['<pad>'])
+
+    assert g1.token_to_idx == {'<unk>': 0, '<pad>': 1, 'c': 2, 'b': 3, 'a': 4,
+                               'some_word$': 5}
+    assert g1.idx_to_token == ['<unk>', '<pad>', 'c', 'b', 'a', 'some_word$']
+
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [0.6, 0.7, 0.8, 0.9, 1],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    assert g1.vec_len == 5
+    assert g1.reserved_tokens == ['<pad>']
+
+    assert_almost_equal(g1.get_vecs_by_tokens('c').asnumpy(),
+                        np.array([1, 1, 1, 1, 1])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['c']).asnumpy(),
+                        np.array([[1, 1, 1, 1, 1]])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['a', 'not_exist']).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['a', 'b']).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['A', 'b']).asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['A', 'b'],
+                                              lower_case_backup=True).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
+
+    g1.update_token_vectors(['a', 'b'],
+                            nd.array([[2, 2, 2, 2, 2],
+                                      [3, 3, 3, 3, 3]])
+                            )
+
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    assertRaises(ValueError, g1.update_token_vectors, 'unknown$$$',
+                 nd.array([0, 0, 0, 0, 0]))
+
+    assertRaises(AssertionError, g1.update_token_vectors, '<unk>',
+                 nd.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))
+
+    assertRaises(AssertionError, g1.update_token_vectors, '<unk>',
+                 nd.array([0]))
+
+    g1.update_token_vectors(['<unk>'],
+                            nd.array([0, 0, 0, 0, 0])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    g1.update_token_vectors(['<unk>'],
+                            nd.array([[10, 10, 10, 10, 10]])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[10, 10, 10, 10, 10],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    g1.update_token_vectors('<unk>',
+                            nd.array([0, 0, 0, 0, 0])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    g1.update_token_vectors('<unk>',
+                            nd.array([[10, 10, 10, 10, 10]])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[10, 10, 10, 10, 10],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+
+def test_glossary_with_two_embeds():
+    embed_root = '.'
+    embed_name = 'my_embed'
+    elem_delim = '\t'
+    pretrain_file1 = 'my_pretrain_file1.txt'
+    pretrain_file2 = 'my_pretrain_file2.txt'
+
+    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim,
+                         pretrain_file1)
+    _mk_my_pretrain_file2(os.path.join(embed_root, embed_name), elem_delim,
+                          pretrain_file2)
+
+    pretrain_file_path1 = os.path.join(embed_root, embed_name, pretrain_file1)
+    pretrain_file_path2 = os.path.join(embed_root, embed_name, pretrain_file2)
+
+    my_embed1 = CustomEmbedding(pretrain_file_path1, elem_delim,
+                                init_unknown_vec=nd.ones)
+    my_embed2 = CustomEmbedding(pretrain_file_path2, elem_delim)
+
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    g1 = Glossary(counter, [my_embed1, my_embed2], most_freq_count=None,
+                  min_freq=1, unknown_token='<unk>', reserved_tokens=None)
+
+    assert g1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
+                               'some_word$': 4}
+    assert g1.idx_to_token == ['<unk>', 'c', 'b', 'a', 'some_word$']
+
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5,
+                                   0.01, 0.02, 0.03, 0.04, 0.05],
+                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
+                        )
+
+    assert g1.vec_len == 10
+    assert g1.reserved_tokens is None
+    assert_almost_equal(g1.get_vecs_by_tokens('c').asnumpy(),
+                        np.array([1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['b', 'not_exist']).asnumpy(),
+                        np.array([[0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
+                        )
+
+    g1.update_token_vectors(['a', 'b'],
+                            nd.array([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
+                                      [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
+                        )
+
+    # Test loaded unknown tokens
+    pretrain_file3 = 'my_pretrain_file3.txt'
+    pretrain_file4 = 'my_pretrain_file4.txt'
+
+    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim,
+                          pretrain_file3)
+    _mk_my_pretrain_file4(os.path.join(embed_root, embed_name), elem_delim,
+                          pretrain_file4)
+
+    pretrain_file_path3 = os.path.join(embed_root, embed_name, pretrain_file3)
+    pretrain_file_path4 = os.path.join(embed_root, embed_name, pretrain_file4)
+
+    my_embed3 = CustomEmbedding(pretrain_file_path3, elem_delim,
+                                init_unknown_vec=nd.ones,
+                                unknown_token='<unk1>')
+    my_embed4 = CustomEmbedding(pretrain_file_path4, elem_delim,
+                                unknown_token='<unk2>')
+
+    g2 = Glossary(counter, [my_embed3, my_embed4], most_freq_count=None,
+                  min_freq=1, unknown_token='<unk>', reserved_tokens=None)
+    assert_almost_equal(g2.idx_to_vec.asnumpy(),
+                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5,
+                                   0.01, 0.02, 0.03, 0.04, 0.05],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15]])
+                        )
+
+    g3 = Glossary(counter, [my_embed3, my_embed4], most_freq_count=None,
+                  min_freq=1, unknown_token='<unk1>', reserved_tokens=None)
+    assert_almost_equal(g3.idx_to_vec.asnumpy(),
+                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5,
+                                   0.01, 0.02, 0.03, 0.04, 0.05],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15]])
+                        )
+
+    g4 = Glossary(counter, [my_embed3, my_embed4], most_freq_count=None,
+                  min_freq=1, unknown_token='<unk2>', reserved_tokens=None)
+    assert_almost_equal(g4.idx_to_vec.asnumpy(),
+                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5,
+                                   0.01, 0.02, 0.03, 0.04, 0.05],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15]])
+                        )
+
+    counter2 = Counter(['b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    g5 = Glossary(counter2, [my_embed3, my_embed4], most_freq_count=None,
+                  min_freq=1, unknown_token='a', reserved_tokens=None)
+    assert g5.token_to_idx == {'a': 0, 'c': 1, 'b': 2, 'some_word$': 3}
+    assert g5.idx_to_token == ['a', 'c', 'b', 'some_word$']
+    assert_almost_equal(g5.idx_to_vec.asnumpy(),
+                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15]])
+                        )
+
+
+def test_get_embedding_names_and_pretrain_files():
+    assert len(TokenEmbedding.get_embedding_and_pretrained_file_names(
+        embedding_name='fasttext')) == 3
+
+    assert len(TokenEmbedding.get_embedding_and_pretrained_file_names(
+        embedding_name='glove')) == 10
+
+    reg = TokenEmbedding.get_embedding_and_pretrained_file_names(
+        embedding_name=None)
+
+    assert len(reg['glove']) == 10
+    assert len(reg['fasttext']) == 3
+
+    assertRaises(KeyError,
+                 TokenEmbedding.get_embedding_and_pretrained_file_names,
+                 'unknown$$')
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
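
The two rules that test_count_tokens_from_str and test_token_indexer pin down most precisely can be summarized in a small pure-Python sketch. The names count_tokens_sketch and build_index_sketch are hypothetical; they are written only to reproduce the outputs the assertions above expect and are not the mxnet.text implementation. The rules are inferred from the tests and should be read as assumptions: either delimiter splits tokens and empty strings are dropped; the unknown token takes index 0, reserved tokens follow in the given order, and the remaining counter tokens are added by descending frequency with ties broken by ascending token order, skipping special tokens and stopping at min_freq or after most_freq_count additions.

```python
import re
from collections import Counter


def count_tokens_sketch(source_str, token_delim, seq_delim, to_lower=False,
                        counter_to_update=None):
    """Hypothetical stand-in for utils.count_tokens_from_str (assumed rules)."""
    # Split on either delimiter; drop the empty strings that leading,
    # trailing, or adjacent delimiters leave behind.
    pattern = re.escape(token_delim) + '|' + re.escape(seq_delim)
    tokens = [tok for tok in re.split(pattern, source_str) if tok]
    if to_lower:
        tokens = [tok.lower() for tok in tokens]
    # Either start a fresh Counter or update the caller's Counter in place.
    counter = Counter() if counter_to_update is None else counter_to_update
    counter.update(tokens)
    return counter


def build_index_sketch(counter, most_freq_count=None, min_freq=1,
                       unknown_token='<unk>', reserved_tokens=None):
    """Hypothetical stand-in for TokenIndexer's index assignment."""
    # Special tokens come first: unknown token at 0, then reserved tokens.
    idx_to_token = [unknown_token] + list(reserved_tokens or [])
    special = set(idx_to_token)
    # Most frequent first; ties broken by ascending token order, which
    # requires tokens to be mutually comparable (strings, tuples, ...).
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    added = 0
    for token, freq in ordered:
        if freq < min_freq:
            break
        if most_freq_count is not None and added == most_freq_count:
            break
        if token in special:
            continue  # already indexed as the unknown or a reserved token
        idx_to_token.append(token)
        added += 1
    token_to_idx = {tok: idx for idx, tok in enumerate(idx_to_token)}
    return token_to_idx, idx_to_token
```

Sorting on the composite key (-frequency, token) is what makes 'a' precede 'some_word$' in g1 even though both occur once, and because special tokens are skipped rather than re-added, an unknown token drawn from the counter (as in g12 and g13) keeps index 0.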