You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/15 17:21:42 UTC

[incubator-mxnet] 01/04: gluon language modeling dataset and text token reader (#9986)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch nlp_toolkit
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 929f5ef630f8f9e5cf6bef6b419282932f531fd0
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Thu Mar 8 16:49:51 2018 -0500

    gluon language modeling dataset and text token reader (#9986)
    
    * language modeling dataset and text token reader.
    
    * update
    
    * add padding
    
    * update bos insert
    
    * update doc
---
 example/gluon/word_language_model/train.py         |  44 ++++---
 python/mxnet/gluon/data/__init__.py                |   2 +
 .../gluon/data/{__init__.py => datareader.py}      |  18 ++-
 python/mxnet/gluon/data/{ => text}/__init__.py     |  10 +-
 .../gluon/data/{__init__.py => text/_constants.py} |  12 +-
 python/mxnet/gluon/data/text/base.py               | 103 +++++++++++++++
 python/mxnet/gluon/data/text/lm.py                 | 145 +++++++++++++++++++++
 python/mxnet/gluon/data/text/utils.py              |  73 +++++++++++
 tests/python/unittest/test_gluon_data_text.py      |  50 +++++++
 9 files changed, 420 insertions(+), 37 deletions(-)

diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model/train.py
index b69fd17..c732393 100644
--- a/example/gluon/word_language_model/train.py
+++ b/example/gluon/word_language_model/train.py
@@ -16,13 +16,13 @@
 # under the License.
 
 import argparse
+import collections
 import time
 import math
 import mxnet as mx
-from mxnet import gluon, autograd
-from mxnet.gluon import contrib
+from mxnet import gluon, autograd, contrib
+from mxnet.gluon import data
 import model
-import data
 
 parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
 parser.add_argument('--model', type=str, default='lstm',
@@ -71,32 +71,40 @@ if args.cuda:
 else:
     context = mx.cpu(0)
 
-train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
-vocab = train_dataset.vocabulary
-val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
-                                                         vocab=vocab,
-                                                         seq_len=args.bptt)
-                             for segment in ['validation', 'test']]
+train_dataset = data.text.lm.WikiText2('./data', 'train', seq_len=args.bptt,
+                                       eos='<eos>')
+
+def get_frequencies(dataset):
+    return collections.Counter(x for tup in dataset for x in tup[0] if x)  # data-side tokens only; falsy tokens skipped
+
+vocab = contrib.text.vocab.Vocabulary(get_frequencies(train_dataset))  # built from the training split only
+def index_tokens(data, label):  # NOTE(review): param `data` shadows the imported gluon `data` module here
+    return vocab.to_indices(data), vocab.to_indices(label)
+
+val_dataset, test_dataset = [data.text.lm.WikiText2('./data', segment,
+                                                    seq_len=args.bptt,
+                                                    eos='<eos>')
+                             for segment in ['val', 'test']]
 
 nbatch_train = len(train_dataset) // args.batch_size
-train_data = gluon.data.DataLoader(train_dataset,
+train_data = gluon.data.DataLoader(train_dataset.transform(index_tokens),
                                    batch_size=args.batch_size,
-                                   sampler=contrib.data.IntervalSampler(len(train_dataset),
-                                                                        nbatch_train),
+                                   sampler=gluon.contrib.data.IntervalSampler(len(train_dataset),
+                                                                              nbatch_train),
                                    last_batch='discard')
 
 nbatch_val = len(val_dataset) // args.batch_size
-val_data = gluon.data.DataLoader(val_dataset,
+val_data = gluon.data.DataLoader(val_dataset.transform(index_tokens),
                                  batch_size=args.batch_size,
-                                 sampler=contrib.data.IntervalSampler(len(val_dataset),
-                                                                      nbatch_val),
+                                 sampler=gluon.contrib.data.IntervalSampler(len(val_dataset),
+                                                                            nbatch_val),
                                  last_batch='discard')
 
 nbatch_test = len(test_dataset) // args.batch_size
-test_data = gluon.data.DataLoader(test_dataset,
+test_data = gluon.data.DataLoader(test_dataset.transform(index_tokens),
                                   batch_size=args.batch_size,
-                                  sampler=contrib.data.IntervalSampler(len(test_dataset),
-                                                                       nbatch_test),
+                                  sampler=gluon.contrib.data.IntervalSampler(len(test_dataset),
+                                                                             nbatch_test),
                                   last_batch='discard')
 
 
diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/__init__.py
index 23ae3e9..14a0e46 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/__init__.py
@@ -26,3 +26,5 @@ from .sampler import *
 from .dataloader import *
 
 from . import vision
+
+from . import text
diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/datareader.py
similarity index 64%
copy from python/mxnet/gluon/data/__init__.py
copy to python/mxnet/gluon/data/datareader.py
index 23ae3e9..9b94ed4 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/datareader.py
@@ -16,13 +16,19 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=wildcard-import
-"""Dataset utilities."""
+# pylint: disable=
+"""Dataset reader."""
+__all__ = ['DataReader']
 
-from .dataset import *
+class DataReader(object):
+    """Abstract datareader class. Data reader handles I/O and produces raw samples for a dataset.
 
-from .sampler import *
+    Subclasses need to override `read` that returns a Dataset (and optionally `read_iter` that
+    returns an iterable for large files).
+    """
 
-from .dataloader import *
+    def read(self):
+        raise NotImplementedError  # abstract: every subclass must override read()
 
-from . import vision
+    def read_iter(self):
+        return self.read()  # default: no streaming, falls back to the eager read()
diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/text/__init__.py
similarity index 87%
copy from python/mxnet/gluon/data/__init__.py
copy to python/mxnet/gluon/data/text/__init__.py
index 23ae3e9..5a6c097 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/text/__init__.py
@@ -17,12 +17,8 @@
 
 # coding: utf-8
 # pylint: disable=wildcard-import
-"""Dataset utilities."""
+"""Text utilities."""
 
-from .dataset import *
+from .base import *
 
-from .sampler import *
-
-from .dataloader import *
-
-from . import vision
+from . import lm
diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/text/_constants.py
similarity index 84%
copy from python/mxnet/gluon/data/__init__.py
copy to python/mxnet/gluon/data/text/_constants.py
index 23ae3e9..9aeac09 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/text/_constants.py
@@ -16,13 +16,13 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=wildcard-import
-"""Dataset utilities."""
 
-from .dataset import *
+"""Constants relevant to text processing."""
 
-from .sampler import *
+UNK_TOKEN = '<unk>'
 
-from .dataloader import *
+BOS_TOKEN = '<bos>'
 
-from . import vision
+EOS_TOKEN = '<eos>'
+
+PAD_TOKEN = '<pad>'
diff --git a/python/mxnet/gluon/data/text/base.py b/python/mxnet/gluon/data/text/base.py
new file mode 100644
index 0000000..a9fa25b
--- /dev/null
+++ b/python/mxnet/gluon/data/text/base.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+
+"""Base classes for text datasets and readers."""
+
+__all__ = ['WordLanguageReader']
+
+import io
+import os
+
+from ..dataset import SimpleDataset
+from ..datareader import DataReader
+from .utils import flatten_samples, collate, pair
+
+class WordLanguageReader(DataReader):
+    """Text reader that reads a whole corpus and produces a dataset based on provided
+    sample splitter and word tokenizer.
+
+    The returned dataset includes data (current word) and label (next word).
+
+    Parameters
+    ----------
+    filename : str
+        Path to the input text file.
+    encoding : str, default 'utf8'
+        File encoding format.
+    sample_splitter : function, default str.splitlines
+        A function that splits the dataset string into samples.
+    tokenizer : function, default str.split
+        A function that splits each sample string into a list of tokens.
+    seq_len : int or None, default None
+        The length of each of the samples. If None, samples are divided according to
+        `sample_splitter` only, and may have variable lengths.
+    bos : str or None, default None
+        The token to add at the beginning of each sentence. If None, nothing is added.
+    eos : str or None, default None
+        The token to add at the end of each sentence. If None, nothing is added.
+    pad : str or None, default None
+        The padding token to add at the end of dataset if `seq_len` is specified and the total
+        number of tokens in the corpus is not a multiple of `seq_len`. If pad is None or seq_len
+        is None, no padding is added. Otherwise, padding token is added to the last sample if
+        its length is less than `seq_len`. If `pad` is None and `seq_len` is specified, the last
+        sample is discarded if it's shorter than `seq_len`.
+    """
+    def __init__(self, filename, encoding='utf8', sample_splitter=lambda s: s.splitlines(),
+                 tokenizer=lambda s: s.split(), seq_len=None, bos=None, eos=None, pad=None):
+        self._filename = os.path.expanduser(filename)
+        self._encoding = encoding
+        self._sample_splitter = sample_splitter
+        self._tokenizer = tokenizer
+
+        # Wrap each token list with the optional bos/eos markers (falsy markers are
+        # skipped), then turn it into (current token, next token) pairs.
+        def process(s):
+            out = [bos] if bos else []
+            out.extend(s)
+            if eos:
+                out.append(eos)
+            return pair(out)
+        self._process = process
+        self._seq_len = seq_len
+        self._pad = pad
+
+    def read(self):
+        """Read the corpus and build the dataset.
+
+        Returns
+        -------
+        SimpleDataset
+            Samples of the form ``[data, label]``, where ``label[i]`` is the token that
+            follows ``data[i]`` in its split sample; padded positions hold ``pad`` in both.
+        """
+        with io.open(self._filename, 'r', encoding=self._encoding) as fin:
+            content = fin.read()
+        samples = [s.strip() for s in self._sample_splitter(content)]
+        samples = [self._process(self._tokenizer(s)) for s in samples if s]
+        if self._seq_len:
+            samples = flatten_samples(samples)
+            if self._pad and len(samples) % self._seq_len:
+                pad_len = self._seq_len - len(samples) % self._seq_len
+                # Entries are (data, label) pairs here, so pad with pairs: a bare string
+                # would be split into characters by zip(*s) below.
+                samples.extend([(self._pad, self._pad)] * pad_len)
+            samples = collate(samples, self._seq_len)
+        samples = [list(zip(*s)) for s in samples]
+        return SimpleDataset(samples)
diff --git a/python/mxnet/gluon/data/text/lm.py b/python/mxnet/gluon/data/text/lm.py
new file mode 100644
index 0000000..63ad004
--- /dev/null
+++ b/python/mxnet/gluon/data/text/lm.py
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Language model datasets."""
+
+__all__ = ['WikiText2', 'WikiText103']
+
+import os
+import zipfile
+import shutil
+
+from ..dataset import SimpleDataset
+from .base import WordLanguageReader
+from . import _constants as C
+from ...utils import download, check_sha1, _get_repo_file_url
+
+
+class _WikiText(SimpleDataset):  # subclasses must set _archive_file/_data_file/_segment before __init__
+    def __init__(self, root, namespace, seq_len, bos, eos, pad):
+        root = os.path.expanduser(root)
+        if not os.path.isdir(root):
+            os.makedirs(root)
+        self._root = root
+        self._namespace = 'gluon/dataset/{}'.format(namespace)
+        reader = WordLanguageReader(self._get_data(),
+                                    seq_len=seq_len, bos=bos, eos=eos, pad=pad)
+        super(_WikiText, self).__init__(reader.read())
+
+    def _get_data(self):
+        archive_file_name, archive_hash = self._archive_file
+        data_file_name, data_hash = self._data_file[self._segment]  # KeyError on an unknown segment
+        root = self._root
+        path = os.path.join(root, data_file_name)
+        if not os.path.exists(path) or not check_sha1(path, data_hash):  # missing or corrupt: re-fetch
+            downloaded_file_path = download(_get_repo_file_url(self._namespace, archive_file_name),
+                                            path=root,
+                                            sha1_hash=archive_hash)
+
+            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
+                for member in zf.namelist():  # extracts every file, not only the requested segment
+                    filename = os.path.basename(member)
+                    if filename:  # directory entries have an empty basename; skip them
+                        dest = os.path.join(root, filename)  # flattens the archive layout into root
+                        with zf.open(member) as source, \
+                             open(dest, "wb") as target:
+                            shutil.copyfileobj(source, target)
+        return path
+
+
+class WikiText2(_WikiText):
+    """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+
+    From
+    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+    License: Creative Commons Attribution-ShareAlike
+
+    Parameters
+    ----------
+    root : str, default '~/.mxnet/datasets/wikitext-2'
+        Directory in which the dataset archive is downloaded and extracted.
+    segment : str, default 'train'
+        Dataset segment. Options are 'train', 'val', 'test'.
+    seq_len : int or None, default 35
+        The number of tokens for each sample. If specified, samples are collated by length.
+        If None, each sample is of variable length.
+    bos : str or None, default None
+        The token to add at the beginning of each sentence. If None, nothing is added.
+    eos : str or None, default '<eos>'
+        The token to add at the end of each sentence. If None, nothing is added.
+    pad : str or None, default '<pad>'
+        The padding token to add at the end of dataset if `seq_len` is specified and the total
+        number of tokens in the corpus is not a multiple of `seq_len`. If pad is None or seq_len
+        is None, no padding is added. Otherwise, padding token is added to the last sample if
+        its length is less than `seq_len`. If `pad` is None and `seq_len` is specified, the last
+        sample is discarded if it's shorter than `seq_len`.
+    """
+    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
+                 segment='train', seq_len=35, bos=None, eos=C.EOS_TOKEN, pad=C.PAD_TOKEN):
+        self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
+        self._data_file = {'train': ('wiki.train.tokens',
+                                     '863f29c46ef9d167fff4940ec821195882fe29d1'),
+                           'val': ('wiki.valid.tokens',
+                                   '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
+                           'test': ('wiki.test.tokens',
+                                    'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
+        self._segment = segment  # must be set before super().__init__, which reads it
+        super(WikiText2, self).__init__(root, 'wikitext-2', seq_len, bos, eos, pad)
+
+
+class WikiText103(_WikiText):
+    """WikiText-103 word-level dataset for language modeling, from Salesforce research.
+
+    From
+    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+    License: Creative Commons Attribution-ShareAlike
+
+    Parameters
+    ----------
+    root : str, default '~/.mxnet/datasets/wikitext-103'
+        Directory in which the dataset archive is downloaded and extracted.
+    segment : str, default 'train'
+        Dataset segment. Options are 'train', 'val', 'test'.
+    seq_len : int or None, default 35
+        The number of tokens for each sample. If specified, samples are collated by length.
+        If None, each sample is of variable length.
+    bos : str or None, default None
+        The token to add at the beginning of each sentence. If None, nothing is added.
+    eos : str or None, default '<eos>'
+        The token to add at the end of each sentence. If None, nothing is added.
+    pad : str or None, default '<pad>'
+        The padding token to add at the end of dataset if `seq_len` is specified and the total
+        number of tokens in the corpus is not a multiple of `seq_len`. If pad is None or seq_len
+        is None, no padding is added. Otherwise, padding token is added to the last sample if
+        its length is less than `seq_len`. If `pad` is None and `seq_len` is specified, the last
+        sample is discarded if it's shorter than `seq_len`.
+    """
+    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
+                 segment='train', seq_len=35, bos=None, eos=C.EOS_TOKEN, pad=C.PAD_TOKEN):
+        self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
+        self._data_file = {'train': ('wiki.train.tokens',
+                                     'b7497e2dfe77e72cfef5e3dbc61b7b53712ac211'),
+                           'val': ('wiki.valid.tokens',
+                                   'c326ac59dc587676d58c422eb8a03e119582f92b'),
+                           'test': ('wiki.test.tokens',
+                                    '8a5befc548865cec54ed4273cf87dbbad60d1e47')}
+        self._segment = segment  # must be set before super().__init__, which reads it
+        super(WikiText103, self).__init__(root, 'wikitext-103', seq_len, bos, eos, pad)
diff --git a/python/mxnet/gluon/data/text/utils.py b/python/mxnet/gluon/data/text/utils.py
new file mode 100644
index 0000000..68fa519
--- /dev/null
+++ b/python/mxnet/gluon/data/text/utils.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=redefined-builtin
+
+"""Utility functions."""
+
+import sys
+if sys.version_info[0] < 3:
+    from itertools import izip as zip
+
+def flatten_samples(samples):
+    """Flatten list of list of tokens into a single flattened list of tokens.
+
+    Parameters
+    ----------
+    samples : list of list of object
+        List of samples, each of which is a list of tokens.
+
+    Returns
+    -------
+    Flattened list of tokens. Falsy tokens (e.g. '' or None) are dropped.
+    """
+    return [token for sample in samples for token in sample if token]
+
+def collate(flat_sample, seq_len):
+    """Collate a flat list of tokens into list of list of tokens, with each
+    inner list's length equal to the specified `seq_len`.
+
+    Parameters
+    ----------
+    flat_sample : list of object
+        A flat list of tokens.
+    seq_len : int
+        The length of each of the samples.
+
+    Returns
+    -------
+    List of samples, each of length `seq_len`; an incomplete trailing chunk is dropped.
+    """
+    num_samples = len(flat_sample) // seq_len  # floor division discards any remainder
+    return [flat_sample[i*seq_len:(i+1)*seq_len] for i in range(num_samples)]
+
+def pair(sample):
+    """Produce tuples of tokens from a list of tokens, with current token as the first
+    element and the next token as the second element.
+
+    Parameters
+    ----------
+    sample : list of object
+        A list of tokens.
+
+    Returns
+    -------
+    List of tuples, each of which has current token as the first element and the next token
+    as the second element.
+    """
+    return list(zip(sample[:-1], sample[1:]))  # empty when sample has fewer than two tokens
diff --git a/tests/python/unittest/test_gluon_data_text.py b/tests/python/unittest/test_gluon_data_text.py
new file mode 100644
index 0000000..a7c083f
--- /dev/null
+++ b/tests/python/unittest/test_gluon_data_text.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+import collections
+import mxnet as mx
+from mxnet.gluon import nn, data
+from common import setup_module, with_seed
+
+def get_frequencies(dataset):
+    return collections.Counter(x for tup in dataset for x in tup[0]+tup[1][-1:])  # data tokens + last label token
+
+def test_wikitext2():  # downloads WikiText-2 into data/wikitext-2 if not already present
+    train = data.text.lm.WikiText2(root='data/wikitext-2', segment='train')
+    val = data.text.lm.WikiText2(root='data/wikitext-2', segment='val')
+    test = data.text.lm.WikiText2(root='data/wikitext-2', segment='test')
+    train_freq, val_freq, test_freq = [get_frequencies(x) for x in [train, val, test]]
+    assert len(train) == 58626
+    assert len(train_freq) == 33278  # NOTE(review): includes tokens of the padded tail sample
+    assert len(val) == 6112
+    assert len(val_freq) == 13778
+    assert len(test) == 6892
+    assert len(test_freq) == 14144
+    assert test_freq['English'] == 35
+    assert len(train[0][0]) == 35
+    test_no_pad = data.text.lm.WikiText2(root='data/wikitext-2', segment='test', pad=None)
+    assert len(test_no_pad) == 6891  # without padding the incomplete final sample is discarded
+
+    train_paragraphs = data.text.lm.WikiText2(root='data/wikitext-2', segment='train', seq_len=None)
+    assert len(train_paragraphs) == 23767
+    assert len(train_paragraphs[0][0]) != 35  # seq_len=None: samples keep their variable lengths
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.