You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sx...@apache.org on 2018/03/08 21:50:01 UTC
[incubator-mxnet] branch nlp_toolkit updated: gluon language
modeling dataset and text token reader (#9986)
This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch nlp_toolkit
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/nlp_toolkit by this push:
new 329acde gluon language modeling dataset and text token reader (#9986)
329acde is described below
commit 329acde5a722f7be44604dd601884592945755e1
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Thu Mar 8 16:49:51 2018 -0500
gluon language modeling dataset and text token reader (#9986)
* language modeling dataset and text token reader.
* update
* add padding
* update bos insert
* update doc
---
example/gluon/word_language_model/train.py | 44 ++++---
python/mxnet/gluon/data/__init__.py | 2 +
.../gluon/data/{__init__.py => datareader.py} | 18 ++-
python/mxnet/gluon/data/{ => text}/__init__.py | 10 +-
.../gluon/data/{__init__.py => text/_constants.py} | 12 +-
python/mxnet/gluon/data/text/base.py | 103 +++++++++++++++
python/mxnet/gluon/data/text/lm.py | 145 +++++++++++++++++++++
python/mxnet/gluon/data/text/utils.py | 73 +++++++++++
tests/python/unittest/test_gluon_data_text.py | 50 +++++++
9 files changed, 420 insertions(+), 37 deletions(-)
diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model/train.py
index b69fd17..c732393 100644
--- a/example/gluon/word_language_model/train.py
+++ b/example/gluon/word_language_model/train.py
@@ -16,13 +16,13 @@
# under the License.
import argparse
+import collections
import time
import math
import mxnet as mx
-from mxnet import gluon, autograd
-from mxnet.gluon import contrib
+from mxnet import gluon, autograd, contrib
+from mxnet.gluon import data
import model
-import data
parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
parser.add_argument('--model', type=str, default='lstm',
@@ -71,32 +71,40 @@ if args.cuda:
else:
context = mx.cpu(0)
-train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
-vocab = train_dataset.vocabulary
-val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
- vocab=vocab,
- seq_len=args.bptt)
- for segment in ['validation', 'test']]
+train_dataset = data.text.lm.WikiText2('./data', 'train', seq_len=args.bptt,
+ eos='<eos>')
+# Count the data-side tokens of every (data, label) training sample, skipping falsy ones.
+def get_frequencies(dataset):
+    return collections.Counter(x for tup in dataset for x in tup[0] if x)
+# The vocabulary comes from training-set counts only and is reused for val/test below.
+vocab = contrib.text.vocab.Vocabulary(get_frequencies(train_dataset))
+def index_tokens(data, label):  # NOTE(review): `data` shadows the imported module here
+    return vocab.to_indices(data), vocab.to_indices(label)
+# Validation/test splits use the same seq_len and eos settings as the training split.
+val_dataset, test_dataset = [data.text.lm.WikiText2('./data', segment,
+ seq_len=args.bptt,
+ eos='<eos>')
+ for segment in ['val', 'test']]
nbatch_train = len(train_dataset) // args.batch_size
-train_data = gluon.data.DataLoader(train_dataset,
+train_data = gluon.data.DataLoader(train_dataset.transform(index_tokens),
batch_size=args.batch_size,
- sampler=contrib.data.IntervalSampler(len(train_dataset),
- nbatch_train),
+ sampler=gluon.contrib.data.IntervalSampler(len(train_dataset),
+ nbatch_train),
last_batch='discard')
nbatch_val = len(val_dataset) // args.batch_size
-val_data = gluon.data.DataLoader(val_dataset,
+val_data = gluon.data.DataLoader(val_dataset.transform(index_tokens),
batch_size=args.batch_size,
- sampler=contrib.data.IntervalSampler(len(val_dataset),
- nbatch_val),
+ sampler=gluon.contrib.data.IntervalSampler(len(val_dataset),
+ nbatch_val),
last_batch='discard')
nbatch_test = len(test_dataset) // args.batch_size
-test_data = gluon.data.DataLoader(test_dataset,
+test_data = gluon.data.DataLoader(test_dataset.transform(index_tokens),
batch_size=args.batch_size,
- sampler=contrib.data.IntervalSampler(len(test_dataset),
- nbatch_test),
+ sampler=gluon.contrib.data.IntervalSampler(len(test_dataset),
+ nbatch_test),
last_batch='discard')
diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/__init__.py
index 23ae3e9..14a0e46 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/__init__.py
@@ -26,3 +26,5 @@ from .sampler import *
from .dataloader import *
from . import vision
+
+from . import text
diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/datareader.py
similarity index 64%
copy from python/mxnet/gluon/data/__init__.py
copy to python/mxnet/gluon/data/datareader.py
index 23ae3e9..9b94ed4 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/datareader.py
@@ -16,13 +16,19 @@
# under the License.
# coding: utf-8
-# pylint: disable=wildcard-import
-"""Dataset utilities."""
+# pylint: disable=
+"""Dataset reader."""
+__all__ = ['DataReader']
-from .dataset import *
+class DataReader(object):
+    """Abstract data reader. A data reader handles I/O and produces raw samples for a dataset.
-from .sampler import *
+    Subclasses must override `read` to return a Dataset, and may override `read_iter` to
+    return an iterable instead (e.g. for large files).
+    """
-from .dataloader import *
+    def read(self):
+        raise NotImplementedError  # abstract: must be provided by subclasses
-from . import vision
+    def read_iter(self):
+        return self.read()  # default: simply delegate to read()
diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/text/__init__.py
similarity index 87%
copy from python/mxnet/gluon/data/__init__.py
copy to python/mxnet/gluon/data/text/__init__.py
index 23ae3e9..5a6c097 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/text/__init__.py
@@ -17,12 +17,8 @@
# coding: utf-8
# pylint: disable=wildcard-import
-"""Dataset utilities."""
+"""Text utilities."""
-from .dataset import *
+from .base import *
-from .sampler import *
-
-from .dataloader import *
-
-from . import vision
+from . import lm
diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/text/_constants.py
similarity index 84%
copy from python/mxnet/gluon/data/__init__.py
copy to python/mxnet/gluon/data/text/_constants.py
index 23ae3e9..9aeac09 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/text/_constants.py
@@ -16,13 +16,13 @@
# under the License.
# coding: utf-8
-# pylint: disable=wildcard-import
-"""Dataset utilities."""
-from .dataset import *
+"""Constants relevant to text processing."""
-from .sampler import *
+UNK_TOKEN = '<unk>'
-from .dataloader import *
+BOS_TOKEN = '<bos>'
-from . import vision
+EOS_TOKEN = '<eos>'  # default `eos` of WikiText2 / WikiText103
+
+PAD_TOKEN = '<pad>'  # default `pad` of WikiText2 / WikiText103
diff --git a/python/mxnet/gluon/data/text/base.py b/python/mxnet/gluon/data/text/base.py
new file mode 100644
index 0000000..a9fa25b
--- /dev/null
+++ b/python/mxnet/gluon/data/text/base.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+
+"""Base classes for text datasets and readers."""
+
+__all__ = ['WordLanguageReader']
+
+import io
+import os
+
+from ..dataset import SimpleDataset
+from ..datareader import DataReader
+from .utils import flatten_samples, collate, pair
+
+class WordLanguageReader(DataReader):
+    """Text reader that reads a whole corpus and produces a dataset based on provided
+    sample splitter and word tokenizer.
+
+    Each sample of the returned dataset is a list ``[data, label]``, where ``data`` holds the
+    current tokens and ``label`` holds the corresponding next tokens.
+
+    Parameters
+    ----------
+    filename : str
+        Path to the input text file. A leading ``~`` is expanded to the user's home directory.
+    encoding : str, default 'utf8'
+        File encoding format.
+    sample_splitter : function, default str.splitlines
+        A function that splits the dataset string into samples.
+    tokenizer : function, default str.split
+        A function that splits each sample string into an iterable of tokens.
+    seq_len : int or None
+        The length of each of the samples. If None, samples are divided according to
+        `sample_splitter` only, and may have variable lengths.
+    bos : str or None, default None
+        The token to add at the beginning of each sentence. If None, nothing is added.
+    eos : str or None, default None
+        The token to add at the end of each sentence. If None, nothing is added.
+    pad : str or None, default None
+        The padding token to add at the end of the dataset if `seq_len` is specified and the
+        total number of (data, label) pairs in the corpus is not a multiple of `seq_len`. If
+        `pad` or `seq_len` is None, no padding is added. Otherwise, ``(pad, pad)`` pairs are
+        appended so that the last sample's length equals `seq_len`. If `pad` is None and
+        `seq_len` is specified, the last sample is discarded if it's shorter than `seq_len`.
+    """
+    def __init__(self, filename, encoding='utf8', sample_splitter=lambda s: s.splitlines(),
+                 tokenizer=lambda s: s.split(), seq_len=None, bos=None, eos=None, pad=None):
+        self._filename = os.path.expanduser(filename)
+        self._encoding = encoding
+        self._sample_splitter = sample_splitter
+        self._tokenizer = tokenizer
+
+        # Optional bos/eos tokens wrap every sentence before it is turned into
+        # (current, next) pairs. Building a new list instead of appending to the
+        # tokenizer's output avoids mutating it and accepts any iterable of tokens.
+        prefix = [bos] if bos else []
+        suffix = [eos] if eos else []
+
+        def process(tokens):
+            return pair(prefix + list(tokens) + suffix)
+        self._process = process
+        self._seq_len = seq_len
+        self._pad = pad
+
+    def read(self):
+        """Read the whole file and build the dataset.
+
+        Returns
+        -------
+        SimpleDataset
+            Samples of the form ``[data, label]`` (see the class docstring).
+        """
+        with io.open(self._filename, 'r', encoding=self._encoding) as fin:
+            content = fin.read()
+        samples = [s.strip() for s in self._sample_splitter(content)]
+        samples = [self._process(self._tokenizer(s)) for s in samples if s]
+        if self._seq_len:
+            samples = flatten_samples(samples)
+            if self._pad and len(samples) % self._seq_len:
+                pad_len = self._seq_len - len(samples) % self._seq_len
+                # Pad with (pad, pad) pairs so every element stays a (data, label) pair.
+                samples.extend([(self._pad, self._pad)] * pad_len)
+            samples = collate(samples, self._seq_len)
+        # Transpose pairs into [data_tuple, label_tuple]; drop samples with no pair at all.
+        samples = [list(zip(*s)) for s in samples if s]
+        return SimpleDataset(samples)
diff --git a/python/mxnet/gluon/data/text/lm.py b/python/mxnet/gluon/data/text/lm.py
new file mode 100644
index 0000000..1511d9e
--- /dev/null
+++ b/python/mxnet/gluon/data/text/lm.py
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Language model datasets."""
+
+__all__ = ['WikiText2', 'WikiText103']
+
+import os
+import zipfile
+import shutil
+
+from ..dataset import SimpleDataset
+from .base import WordLanguageReader
+from . import _constants as C
+from ...utils import download, check_sha1, _get_repo_file_url
+
+
+class _WikiText(SimpleDataset):
+    """Shared base for the WikiText datasets; subclasses define the archive and file attributes."""
+    def __init__(self, root, seq_len, bos, eos, pad):
+        # Expand '~' so the default roots resolve to the home directory, not a literal '~' dir.
+        self._root = root = os.path.expanduser(root)
+        if not os.path.isdir(root):
+            os.makedirs(root)
+        reader = WordLanguageReader(self._get_data(),
+                                    seq_len=seq_len, bos=bos, eos=eos, pad=pad)
+        super(_WikiText, self).__init__(reader.read())
+
+    def _get_data(self):
+        """Return the path of the segment file, downloading and extracting it if needed."""
+        archive_file_name, archive_hash = self._archive_file
+        data_file_name, data_hash = self._data_file[self._segment]
+        root = self._root
+        path = os.path.join(root, data_file_name)
+        if not os.path.exists(path) or not check_sha1(path, data_hash):
+            downloaded_file_path = download(_get_repo_file_url(self._namespace, archive_file_name),
+                                            path=root, sha1_hash=archive_hash)
+            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
+                for member in zf.namelist():
+                    filename = os.path.basename(member)
+                    if filename:  # skip directory entries; files are extracted flat into root
+                        dest = os.path.join(root, filename)
+                        with zf.open(member) as source, open(dest, "wb") as target:
+                            shutil.copyfileobj(source, target)
+        return path
+
+
+class WikiText2(_WikiText):
+    """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+
+    From
+    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+    License: Creative Commons Attribution-ShareAlike
+
+    Parameters
+    ----------
+    root : str, default '~/.mxnet/datasets/wikitext-2'
+        Path to the folder where the data files are stored (downloaded if missing or corrupt).
+    segment : str, default 'train'
+        Dataset segment. Options are 'train', 'val', 'test'.
+    seq_len : int or None, default 35
+        The number of tokens for each sample. If specified, the corpus is concatenated and
+        cut into samples of this length. If None, each sample is of variable length.
+    bos : str or None, default None
+        The token to add at the beginning of each sentence. If None, nothing is added.
+    eos : str or None, default '<eos>'
+        The token to add at the end of each sentence. If None, nothing is added.
+    pad : str or None, default '<pad>'
+        The padding token to add at the end of the dataset if `seq_len` is specified and the
+        total number of tokens in the corpus is not a multiple of `seq_len`. If `pad` or
+        `seq_len` is None, no padding is added. Otherwise, padding is appended until the last
+        sample's length equals `seq_len`. If `pad` is None and `seq_len` is specified, the
+        last sample is discarded if it's shorter than `seq_len`.
+    """
+    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
+                 segment='train', seq_len=35, bos=None, eos=C.EOS_TOKEN, pad=C.PAD_TOKEN):
+        self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
+        self._data_file = {'train': ('wiki.train.tokens',
+                                     '863f29c46ef9d167fff4940ec821195882fe29d1'),
+                           'val': ('wiki.valid.tokens',
+                                   '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
+                           'test': ('wiki.test.tokens',
+                                    'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
+        self._namespace = 'wikitext-2'
+        self._segment = segment
+        super(WikiText2, self).__init__(root, seq_len, bos, eos, pad)
+
+
+class WikiText103(_WikiText):
+    """WikiText-103 word-level dataset for language modeling, from Salesforce research.
+
+    From
+    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+    License: Creative Commons Attribution-ShareAlike
+
+    Parameters
+    ----------
+    root : str, default '~/.mxnet/datasets/wikitext-103'
+        Path to the folder where the data files are stored (downloaded if missing or corrupt).
+    segment : str, default 'train'
+        Dataset segment. Options are 'train', 'val', 'test'.
+    seq_len : int or None, default 35
+        The number of tokens for each sample. If specified, the corpus is concatenated and
+        cut into samples of this length. If None, each sample is of variable length.
+    bos : str or None, default None
+        The token to add at the beginning of each sentence. If None, nothing is added.
+    eos : str or None, default '<eos>'
+        The token to add at the end of each sentence. If None, nothing is added.
+    pad : str or None, default '<pad>'
+        The padding token to add at the end of the dataset if `seq_len` is specified and the
+        total number of tokens in the corpus is not a multiple of `seq_len`. If `pad` or
+        `seq_len` is None, no padding is added. Otherwise, padding is appended until the last
+        sample's length equals `seq_len`. If `pad` is None and `seq_len` is specified, the
+        last sample is discarded if it's shorter than `seq_len`.
+    """
+    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
+                 segment='train', seq_len=35, bos=None, eos=C.EOS_TOKEN, pad=C.PAD_TOKEN):
+        self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
+        self._data_file = {'train': ('wiki.train.tokens',
+                                     'b7497e2dfe77e72cfef5e3dbc61b7b53712ac211'),
+                           'val': ('wiki.valid.tokens',
+                                   'c326ac59dc587676d58c422eb8a03e119582f92b'),
+                           'test': ('wiki.test.tokens',
+                                    '8a5befc548865cec54ed4273cf87dbbad60d1e47')}
+        self._namespace = 'wikitext-103'
+        self._segment = segment
+        super(WikiText103, self).__init__(root, seq_len, bos, eos, pad)
diff --git a/python/mxnet/gluon/data/text/utils.py b/python/mxnet/gluon/data/text/utils.py
new file mode 100644
index 0000000..68fa519
--- /dev/null
+++ b/python/mxnet/gluon/data/text/utils.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=redefined-builtin
+
+"""Utility functions."""
+
+import sys
+if sys.version_info[0] < 3:
+ from itertools import izip as zip
+
+def flatten_samples(samples):
+    """Concatenate samples into one list of tokens, skipping falsy tokens.
+
+    Parameters
+    ----------
+    samples : list of list of object
+        The samples to concatenate; each one is a list of tokens.
+
+    Returns
+    -------
+    A single list holding every truthy token of every sample, in their original order.
+    """
+    return list(filter(None, (token for sample in samples for token in sample)))
+
+def collate(flat_sample, seq_len):
+    """Split a flat list of tokens into consecutive chunks of `seq_len` tokens each;
+    trailing tokens that cannot fill a whole chunk are dropped.
+
+    Parameters
+    ----------
+    flat_sample : list of object
+        A flat list of tokens.
+    seq_len : int
+        The length of each of the samples.
+
+    Returns
+    -------
+    List of samples, each of which has length equal to `seq_len`.
+    """
+    usable = len(flat_sample) // seq_len * seq_len
+    return [flat_sample[start:start + seq_len] for start in range(0, usable, seq_len)]
+
+def pair(sample):
+    """Produce tuples of tokens from a list of tokens, with current token as the first
+    element and the next token as the second element.
+
+    Parameters
+    ----------
+    sample : list of object
+        A list of tokens.
+
+    Returns
+    -------
+    List of ``(current, next)`` tuples, one per adjacent pair of tokens; empty if `sample`
+    has fewer than two tokens.
+    """
+    return list(zip(sample[:-1], sample[1:]))
diff --git a/tests/python/unittest/test_gluon_data_text.py b/tests/python/unittest/test_gluon_data_text.py
new file mode 100644
index 0000000..a7c083f
--- /dev/null
+++ b/tests/python/unittest/test_gluon_data_text.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+import collections
+import mxnet as mx
+from mxnet.gluon import nn, data
+from common import setup_module, with_seed
+
+def get_frequencies(dataset):
+    return collections.Counter(tok for sample in dataset for tok in sample[0] + sample[1][-1:])
+
+def test_wikitext2():
+    root = 'data/wikitext-2'
+    train, val, test = [data.text.lm.WikiText2(root=root, segment=segment)
+                        for segment in ['train', 'val', 'test']]
+    train_freq, val_freq, test_freq = [get_frequencies(x) for x in [train, val, test]]
+    assert len(train) == 58626
+    assert len(train_freq) == 33278
+    assert len(val) == 6112
+    assert len(val_freq) == 13778
+    assert len(test) == 6892
+    assert len(test_freq) == 14144  # NOTE(review): depends on how padding tokens are emitted
+    assert test_freq['English'] == 35
+    assert len(train[0][0]) == 35
+    test_no_pad = data.text.lm.WikiText2(root=root, segment='test', pad=None)
+    assert len(test_no_pad) == 6891
+    # With seq_len=None, each non-empty line of the corpus becomes one variable-length sample.
+    train_paragraphs = data.text.lm.WikiText2(root=root, segment='train', seq_len=None)
+    assert len(train_paragraphs) == 23767
+    assert len(train_paragraphs[0][0]) != 35
+
+
+if __name__ == '__main__':
+    from nose import runmodule
+    runmodule()
--
To stop receiving notification emails like this one, please contact
sxjscience@apache.org.