Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/01 23:57:12 UTC

[GitHub] astonzhang commented on a change in pull request #8763: Add mxnet.text APIs

astonzhang commented on a change in pull request #8763: Add mxnet.text APIs
URL: https://github.com/apache/incubator-mxnet/pull/8763#discussion_r159166547
 
 

 ##########
 File path: python/mxnet/text/embedding.py
 ##########
 @@ -0,0 +1,722 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=consider-iterating-dictionary
+
+"""Read text files and load embeddings."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+from collections import Counter
+import io
+import logging
+import os
+import tarfile
+import warnings
+import zipfile
+
+from ..gluon.utils import check_sha1
+from ..gluon.utils import download
+from .. import ndarray as nd
+
+
+class TextIndexer(object):
+    """Indexing for text tokens.
+
+
+    Build indices for the unknown token, reserved tokens, and the keys of the
+    input `counter`. Indexed tokens can be used by instances of
+    :class:`~mxnet.text.embedding.TextEmbed`, such as instances of
+    :class:`~mxnet.text.glossary.Glossary`.
+
+
+    Parameters
+    ----------
+    counter : collections.Counter or None, default None
+        Counts text token frequencies in the text data. Its keys will be indexed
+        according to frequency thresholds such as `most_freq_count` and
+        `min_freq`.
+    most_freq_count : None or int, default None
+        The maximum number of the most frequent tokens in the keys of
+        `counter` that can be indexed. Note that this argument does not count
+        any token from `reserved_tokens`. If this argument is None or larger
+        than its largest possible value as restricted by `counter` and
+        `reserved_tokens`, it is treated as positive infinity.
+    min_freq : int, default 1
+        The minimum frequency required for a token in the keys of `counter` to
+        be indexed.
+    unknown_token : str, default '<unk>'
+        The string representation for any unknown token. In other words, any
+        unknown token will be indexed as the same string representation. This
+        string representation cannot be any token to be indexed from the keys of
+        `counter` or from `reserved_tokens`.
+    reserved_tokens : list of strs or None, default None
+        A list of reserved tokens that will always be indexed. It cannot contain
+        `unknown_token`, or duplicate reserved tokens.
+
+
+    Properties
+    ----------
+    token_to_idx : dict mapping str to int
+        A dict mapping each token to its index integer.
+    idx_to_token : list of strs
+        A list of indexed tokens where the list indices and the token indices
+        are aligned.
+    unknown_token : str
+        The string representation for any unknown token. In other words, any
+        unknown token will be indexed as the same string representation.
+    reserved_tokens : list of strs or None
+        A list of reserved tokens that will always be indexed.
+    unknown_idx : int
+        The index for `unknown_token`.
+    """
+    def __init__(self, counter=None, most_freq_count=None, min_freq=1,
+                 unknown_token='<unk>', reserved_tokens=None):
+        # Sanity checks.
+        assert min_freq > 0, '`min_freq` must be set to a positive value.'
+
+        if reserved_tokens is not None:
+            for reserved_token in reserved_tokens:
+                assert reserved_token != unknown_token, \
+                    '`reserved_tokens` cannot contain `unknown_token`.'
+            assert len(set(reserved_tokens)) == len(reserved_tokens), \
+                '`reserved_tokens` cannot contain duplicate reserved tokens.'
+
+        self._index_unknown_and_reserved_tokens(unknown_token, reserved_tokens)
+
+        if counter is not None:
+            self._index_counter_keys(counter, unknown_token, reserved_tokens,
+                                     most_freq_count, min_freq)
+
+    def _index_unknown_and_reserved_tokens(self, unknown_token,
+                                           reserved_tokens):
+        """Indexes unknown and reserved tokens."""
+        self._unknown_token = unknown_token
+        self._idx_to_token = [unknown_token]
+
+        if reserved_tokens is None:
+            self._reserved_tokens = None
+        else:
+            # Python 2 does not support list.copy().
+            self._reserved_tokens = reserved_tokens[:]
+            self._idx_to_token.extend(reserved_tokens)
+
+        self._token_to_idx = {token: idx for idx, token in
+                              enumerate(self._idx_to_token)}
+
+    def _index_counter_keys(self, counter, unknown_token, reserved_tokens,
+                            most_freq_count, min_freq):
+        """Indexes keys of `counter`.
+
+
+        Indexes keys of `counter` according to frequency thresholds such as
+        `most_freq_count` and `min_freq`.
+        """
+        assert isinstance(counter, Counter), \
+            '`counter` must be an instance of collections.Counter.'
+
+        if reserved_tokens is not None:
+            reserved_tokens = set(reserved_tokens)
+        else:
+            reserved_tokens = set()
+
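+        # Sort alphabetically first, then by frequency in descending order;
+        # since the second sort is stable, tokens with the same frequency
+        # keep their alphabetical order.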
+        token_freqs = sorted(counter.items(), key=lambda x: x[0])
+        token_freqs.sort(key=lambda x: x[1], reverse=True)
+
+        if most_freq_count is None:
+            # 1 is the unknown token count.
+            token_cap = 1 + len(reserved_tokens) + len(counter)
+        else:
+            token_cap = 1 + len(reserved_tokens) + most_freq_count
+
+        for token, freq in token_freqs:
+            if freq < min_freq or len(self._idx_to_token) == token_cap:
+                break
+            assert token != unknown_token, \
+                'Keys of `counter` cannot contain `unknown_token`. Set ' \
+                '`unknown_token` to another string representation.'
+            if token not in reserved_tokens:
+                self._idx_to_token.append(token)
+                self._token_to_idx[token] = len(self._idx_to_token) - 1
+
+    def __len__(self):
+        return len(self.idx_to_token)
+
+    @property
+    def token_to_idx(self):
+        return self._token_to_idx
+
+    @property
+    def idx_to_token(self):
+        return self._idx_to_token
+
+    @property
+    def unknown_token(self):
+        return self._unknown_token
+
+    @property
+    def reserved_tokens(self):
+        return self._reserved_tokens
+
+    @property
+    def unknown_idx(self):
+        return 0
+
+
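# A usage sketch based on the docstring above (illustration only, not part of
# this diff): how a TextIndexer built from a Counter assigns indices. The
# module path follows this PR's file layout; the token counts are made up.
from collections import Counter
from mxnet.text.embedding import TextIndexer

counter = Counter({'hello': 3, 'world': 2, 'rare': 1})
indexer = TextIndexer(counter, most_freq_count=None, min_freq=2,
                      unknown_token='<unk>', reserved_tokens=['<pad>'])
# Index 0 is always the unknown token and reserved tokens come next; 'rare'
# falls below min_freq and is therefore not indexed.
# indexer.token_to_idx == {'<unk>': 0, '<pad>': 1, 'hello': 2, 'world': 3}
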
+class TextEmbed(TextIndexer):
+    """Text embedding base class.
+
+
+    To load text embeddings from externally hosted pre-trained text embedding
+    files, such as those of GloVe and FastText, use
+    `TextEmbed.create(embed_name, pretrain_file)`. To get all the
+    available values of `embed_name` and `pretrain_file`, use
+    `TextEmbed.get_embed_names_and_pretrain_files()`.
+
+    Alternatively, to load embedding vectors from a custom pre-trained text
+    embedding file, use :class:`~mxnet.text.embedding.CustomEmbed`.
+
+    For the same token, its index and embedding vector may vary across
+    different instances of :class:`~mxnet.text.embedding.TextEmbed`.
+
+
+    Properties
+    ----------
+    vec_len : int
+        The length of the embedding vector for each token.
+    idx_to_vec : mxnet.ndarray.NDArray
+        For all the indexed tokens in this embedding, this NDArray maps each
+        token's index to an embedding vector. Index 0 (`unknown_idx`) maps to
+        the initialized embedding vector for any unknown token whose string
+        representation is `unknown_token`.
+    """
+
+    # Key-value pairs for text embedding name in lower case and text embedding
+    # class.
+    embed_registry = {}
+
+    def __init__(self, **kwargs):
+        super(TextEmbed, self).__init__(**kwargs)
+
+    @staticmethod
+    def _get_pretrain_file_path_from_url(url, embed_root, embed_name,
+                                         pretrain_file):
+        """Get the local path to the pre-trained text embedding file from url.
+
+
+        The pretrained embedding file will be downloaded from url if it has not
+        been downloaded yet or the existing file fails to match its expected
+        SHA-1 hash.
+        """
+        embed_root = os.path.expanduser(embed_root)
+
+        embed_dir = os.path.join(embed_root, embed_name)
+        pretrain_file_path = os.path.join(embed_dir, pretrain_file)
+        download_file = os.path.basename(url)
+        download_file_path = os.path.join(embed_dir, download_file)
+
+        embed_cls = TextEmbed.embed_registry[embed_name]
+        expected_file_hash = embed_cls.pretrain_file_sha1[pretrain_file]
+
+        if hasattr(embed_cls, 'pretrain_archive_sha1'):
+            expected_download_hash = \
+                embed_cls.pretrain_archive_sha1[download_file]
+        else:
+            expected_download_hash = expected_file_hash
+
+        # The specified pretrained embedding file does not exist or fails to
+        # match its expected SHA-1 hash.
+        if not os.path.isfile(pretrain_file_path) or \
+                not check_sha1(pretrain_file_path, expected_file_hash):
+            # If download_file_path exists and matches
+            # expected_download_hash, there is no need to download.
+            download(url, download_file_path,
+                     sha1_hash=expected_download_hash)
+
+        # If the downloaded file does not match its expected SHA-1 hash,
+        # we discourage loading embeddings from it, since its data format
+        # may have changed.
+        assert check_sha1(download_file_path, expected_download_hash), \
+            'The downloaded file %s does not match its expected SHA-1 ' \
+            'hash. This is likely caused by changes to the externally ' \
+            'hosted pretrained embedding file(s). Since its data format ' \
+            'may also have changed, we discourage continuing to use ' \
+            'mxnet.text.embedding.TextEmbed.create(%s, **kwargs) ' \
+            'to load the pretrained embedding %s. If you still wish to load ' \
+            'the changed embedding file, please specify its path %s via ' \
+            'pretrain_file of mxnet.text.embedding.TextEmbed(). It will be ' \
+            'loaded only if its data format remains unchanged.' % \
+            (download_file_path, embed_name, embed_name, download_file_path)
+
+        ext = os.path.splitext(download_file)[1]
+        if ext == '.zip':
+            with zipfile.ZipFile(download_file_path, 'r') as zf:
+                zf.extractall(embed_dir)
+        elif ext == '.gz':
+            with tarfile.open(download_file_path, 'r:gz') as tar:
+                tar.extractall(path=embed_dir)
+        return pretrain_file_path
+
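# A sketch of the download-and-verify pattern the helper above relies on, using
# mxnet.gluon.utils (illustration only, not part of this diff). The URL, hash,
# and local directory are placeholders, not real values.
import os
from mxnet.gluon.utils import download, check_sha1

url = 'https://example.com/embeddings/some_embedding.zip'   # placeholder URL
expected_sha1 = '0123456789abcdef0123456789abcdef01234567'  # placeholder hash
local_path = os.path.join(os.path.expanduser('~/.mxnet/embeddings'),
                          os.path.basename(url))
# download() skips the transfer when the target file already exists and
# matches the given SHA-1 hash.
download(url, local_path, sha1_hash=expected_sha1)
assert check_sha1(local_path, expected_sha1), 'SHA-1 mismatch after download'
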
+    def _load_embedding(self, pretrain_file_path, elem_delim, unknown_vec):
+        """Load embedding vectors from the pre-trained text embedding file.
+
+
+        Index 0 of `self.idx_to_vec` maps to the initialized embedding vector
+        for every unknown token whose string representation is
+        `self.unknown_token`. For duplicate tokens, only the first-encountered
+        embedding vector will be loaded and the rest will be skipped.
+        """
+        pretrain_file_path = os.path.expanduser(pretrain_file_path)
+
+        if not os.path.isfile(pretrain_file_path):
+            raise ValueError('`pretrain_file_path` must be a valid path to '
+                             'the pre-trained text embedding file.')
+
+        with io.open(pretrain_file_path, 'r', encoding='utf8') as f:
+            lines = f.readlines()
+
+        logging.info('Loading pretrained embedding vectors from %s',
+                     pretrain_file_path)
+
+        vec_len = None
+        all_elems = []
+        tokens = set()
+        for line in lines:
+            elems = line.rstrip().split(elem_delim)
+
+            assert len(elems) > 1, 'The data format of the pre-trained ' \
+                                   'text embedding file %s is unexpected.' \
+                                   % pretrain_file_path
+
+            token, elems = elems[0], [float(i) for i in elems[1:]]
+
+            if token == self.unknown_token:
+                raise ValueError('The string representation of the unknown '
+                                 'token `unknown_token` cannot be any token '
+                                 'from the pre-trained text embedding file.')
+
+            if token in tokens:
+                warnings.warn('The embedding vector for token %s has been '
+                              'loaded and a duplicate embedding for the same '
+                              'token is seen and skipped.' % token)
+            else:
+                if len(elems) == 1:
+                    warnings.warn('Token %s with 1-dimensional vector %s is '
+                                  'likely a header and is skipped.' %
+                                  (token, elems))
+                    continue
+                else:
+                    if vec_len is None:
+                        vec_len = len(elems)
+                        # Reserve a vector slot for the unknown token at the
+                        # very beginning because the unknown index is 0.
+                        all_elems.extend([0] * vec_len)
+                    else:
+                        assert len(elems) == vec_len, \
+                            'The dimension of token %s is %d but the ' \
+                            'dimension of previous tokens is %d. Dimensions ' \
+                            'of all the tokens must be the same.' \
+                            % (token, len(elems), vec_len)
+                all_elems.extend(elems)
+                self._idx_to_token.append(token)
+                self._token_to_idx[token] = len(self._idx_to_token) - 1
+                tokens.add(token)
+
+        self._vec_len = vec_len
+        self._idx_to_vec = nd.array(all_elems).reshape((-1, self.vec_len))
+        self._idx_to_vec[self.unknown_idx] = unknown_vec(shape=self.vec_len)
+
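# A sketch of the plain-text format _load_embedding expects (illustration only,
# not part of this diff): one token per line followed by its vector elements,
# separated by elem_delim (a space here). The values are made up.
sample_lines = [
    'hello 0.1 0.2 0.3',
    'world 0.4 0.5 0.6',
]
for line in sample_lines:
    elems = line.rstrip().split(' ')
    token, vec = elems[0], [float(x) for x in elems[1:]]
    print(token, vec)   # e.g. hello [0.1, 0.2, 0.3]
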
+    @property
+    def vec_len(self):
+        return self._vec_len
+
+    @property
+    def idx_to_vec(self):
+        return self._idx_to_vec
+
+    def __getitem__(self, tokens):
+        """The getter.
+
+        Parameters
+        ----------
+        tokens : str or list of strs
+            A token or a list of tokens.
+
+
+        Returns
+        -------
+        mxnet.ndarray.NDArray:
+            The embedding vector(s) of the token(s). According to numpy
+            conventions, if `tokens` is a string, returns a 1-D NDArray of
+            shape (self.vec_len,); if `tokens` is a list of strings, returns
+            a 2-D NDArray of shape (len(tokens), self.vec_len).
+        """
+        to_reduce = False
+        if not isinstance(tokens, list):
+            tokens = [tokens]
+            to_reduce = True
+
+        indices = [self.token_to_idx[token] if token in self.token_to_idx
+                   else self.unknown_idx for token in tokens]
+
+        vecs = nd.Embedding(nd.array(indices), self.idx_to_vec,
+                            self.idx_to_vec.shape[0], self.idx_to_vec.shape[1])
+
+        return vecs[0] if to_reduce else vecs
+
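# A usage sketch of the lookup behavior documented above (illustration only,
# not part of this diff). 'glove' and the pretrain file name are assumptions;
# TextEmbed.get_embed_names_and_pretrain_files() lists the actual valid values.
from mxnet.text.embedding import TextEmbed

embed = TextEmbed.create('glove', pretrain_file='glove.6B.50d.txt')
single_vec = embed['hello']             # 1-D NDArray of shape (embed.vec_len,)
batch_vecs = embed[['hello', 'world']]  # 2-D NDArray of shape (2, embed.vec_len)
oov_vec = embed['never-seen-token']     # falls back to the unknown-token vector
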
+    def update_token_vectors(self, tokens, new_vectors):
+        """Updates embedding vectors for tokens.
+
+
+        Parameters
+        ----------
+        tokens : str or list of strs
+            A token or a list of tokens whose embedding vectors are to be
+            updated.
+        new_vectors : mxnet.ndarray.NDArray
+            An NDArray to be assigned to the embedding vectors of `tokens`.
+            Its length must be equal to the number of `tokens` and its width
+            must be equal to the dimension of the embedding vectors. If
+            `tokens` is a singleton, it must be 1-D or 2-D. If `tokens` is a
+            list of multiple strings, it must be 2-D.
+        """
+
+        assert self.idx_to_vec is not None, \
+            'The property `idx_to_vec` has not been properly set.'
+
+        if not isinstance(tokens, list) or \
+                (isinstance(tokens, list) and len(tokens) == 1):
+            assert isinstance(new_vectors, nd.NDArray) and \
+                len(new_vectors.shape) in {1, 2}, \
+                '`new_vectors` must be a 1-D or 2-D NDArray if `tokens` is a ' \
+                'singleton.'
+            if not isinstance(tokens, list):
+                tokens = [tokens]
+            if len(new_vectors.shape) == 1:
+                new_vectors = new_vectors.expand_dims(0)
+
+        else:
+            assert isinstance(tokens, list), \
+                '`tokens` must be a string or a list of strings'
+            assert isinstance(new_vectors, nd.NDArray) and \
+                len(new_vectors.shape) == 2, \
+                '`new_vectors` must be a 2-D NDArray if `tokens` is a list ' \
+                'of multiple strings.'
+        assert new_vectors.shape[0] == len(tokens), \
+            'The length of new_vectors must be equal to the number of tokens.'
+        assert new_vectors.shape[1] == self.vec_len, \
+            'The width of new_vectors must be equal to the dimension of ' \
+            'the embedding vectors.'
+
+        indices = []
+        for token in tokens:
+            if token in self.token_to_idx:
+                indices.append(self.token_to_idx[token])
+            else:
+                raise ValueError('Token %s is unknown. To update the embedding '
+                                 'vector for an unknown token, please specify '
+                                 'it explicitly as the `unknown_token` %s in '
+                                 '`tokens`. This is to avoid unintended '
+                                 'updates.' %
+                                 (token, self.idx_to_token[self.unknown_idx]))
+
+        self._idx_to_vec[nd.array(indices)] = new_vectors
+
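# A usage sketch of update_token_vectors (illustration only, not part of this
# diff), continuing the `embed` instance from the sketch above; the tokens
# must already be indexed in `embed`.
from mxnet import nd

embed.update_token_vectors('hello', nd.zeros(embed.vec_len))    # one token, 1-D
embed.update_token_vectors(['hello', 'world'],
                           nd.ones((2, embed.vec_len)))         # two tokens, 2-D
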
+    @staticmethod
+    def register(embed_cls):
+        """Registers a new text embedding.
+
+        Once an embedding is registered, we can create an instance of this
+        embedding with :func:`~mxnet.text.embedding.TextEmbed.create`.
+
+
+        Examples
+        --------
+        >>> @mxnet.text.embedding.TextEmbed.register
+        ... class MyTextEmbed(mxnet.text.embedding.TextEmbed):
+        ...     def __init__(self, pretrain_file='my_pretrain_file'):
+        ...         pass
+        >>> embed = mxnet.text.embedding.TextEmbed.create('MyTextEmbed')
+        >>> print(type(embed))
+        <class '__main__.MyTextEmbed'>
+        """
+
+        assert isinstance(embed_cls, type)
+        embed_name = embed_cls.__name__.lower()
+        if embed_name in TextEmbed.embed_registry:
+            warnings.warn('New embedding %s.%s is overriding existing '
+                          'embedding %s.%s' %
+                          (embed_cls.__module__, embed_cls.__name__,
+                           TextEmbed.embed_registry[embed_name].__module__,
+                           TextEmbed.embed_registry[embed_name].__name__))
+        TextEmbed.embed_registry[embed_name] = embed_cls
+        return embed_cls
+
+    @staticmethod
+    def create(embed_name, **kwargs):
+        """Creates an instance of :func:`~mxnet.text.embedding.TextEmbed`.
+
+        Creates a text embedding instance by loading embedding vectors from
+        externally hosted pre-trained text embedding files, such as those
+        of GloVe and FastText. To get all the valid values of `embed_name`
+        and `pretrain_file`, use
+        `mxnet.text.embedding.TextEmbed.get_embed_names_and_pretrain_files()`.
+
+
+        Parameters
+        ----------
+        embed_name : str
+            The text embedding name (case-insensitive).
+
+
+        Returns
+        -------
+        mxnet.text.embedding.TextEmbed:
+            A text embedding instance that loads embedding vectors from an
+            externally hosted pre-trained text embedding file.
+        """
+        if embed_name.lower() in TextEmbed.embed_registry:
+            return TextEmbed.embed_registry[embed_name.lower()](**kwargs)
+        else:
+            raise ValueError('Cannot find embedding %s. Valid embedding '
+                             'names: %s' %
+                             (embed_name,
+                              ', '.join(TextEmbed.embed_registry.keys())))
+
+    @staticmethod
+    def check_pretrain_files(pretrain_file, embed_name):
+        """Checks if a pre-trained text embedding file name is valid.
+
+
+        Parameters
+        ----------
+        pretrain_file : str
+            The pre-trained text embedding file.
+        embed_name : str
+            The text embedding name (case-insensitive).
+        """
+        embed_name = embed_name.lower()
+        embed_cls = TextEmbed.embed_registry[embed_name]
+        if pretrain_file not in embed_cls.pretrain_file_sha1:
+            raise KeyError('Cannot find pretrain file %s for embedding %s. '
+                           'Valid pretrain files for embedding %s: %s' %
+                           (pretrain_file, embed_name, embed_name,
+                            ', '.join(embed_cls.pretrain_file_sha1.keys())))
+
+    @staticmethod
+    def get_embed_names_and_pretrain_files():
 
 Review comment:
  Thanks. It returns a string rather than printing the string, so I guess "get_" is better than "list_".

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services