You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/30 20:31:20 UTC

[incubator-mxnet] branch master updated: new NER example: MXNET-321 (#10514)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 12709ac  new NER example: MXNET-321 (#10514)
12709ac is described below

commit 12709aca486fc60c28ea968934de95f8d3d0b0b9
Author: Oliver Pringle <oj...@gmail.com>
AuthorDate: Sat Jun 30 13:31:10 2018 -0700

    new NER example: MXNET-321 (#10514)
    
    * new NER example
    
    * removing images
    
    * updating readme to generate preprocessed data
    
    * instructions to download and preprocess training data
---
 example/named_entity_recognition/README.md         |  19 ++
 example/named_entity_recognition/src/iterators.py  | 175 +++++++++++++++
 example/named_entity_recognition/src/metrics.py    |  79 +++++++
 example/named_entity_recognition/src/ner.py        | 236 +++++++++++++++++++++
 example/named_entity_recognition/src/preprocess.py |  50 +++++
 5 files changed, 559 insertions(+)

diff --git a/example/named_entity_recognition/README.md b/example/named_entity_recognition/README.md
new file mode 100644
index 0000000..260c19d
--- /dev/null
+++ b/example/named_entity_recognition/README.md
@@ -0,0 +1,19 @@
+## Goal
+
+- This repo contains an MXNet implementation of this state-of-the-art [entity recognition model](https://www.aclweb.org/anthology/Q16-1026).
+- You can find my blog post on the model [here](https://opringle.github.io/2018/02/06/CNNLSTM_entity_recognition.html).
+
+![](https://github.com/dmlc/web-data/blob/master/mxnet/example/ner/arch1.png?raw=true)
+
+## Running the code
+
+To reproduce the preprocessed training data:
+
+1. Download and unzip the data: https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus/downloads/ner_dataset.csv
+2. Move ner_dataset.csv into `./data`
+3. Create the `./preprocessed_data` directory
+4. `$ cd src && python preprocess.py`
+
+To train the model:
+
+- `$ cd src && python ner.py`
\ No newline at end of file
diff --git a/example/named_entity_recognition/src/iterators.py b/example/named_entity_recognition/src/iterators.py
new file mode 100644
index 0000000..a11c570
--- /dev/null
+++ b/example/named_entity_recognition/src/iterators.py
@@ -0,0 +1,175 @@
+# !/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+import bisect
+import random
+import numpy as np
+from mxnet.io import DataIter, DataBatch, DataDesc
+from mxnet import ndarray
+from sklearn.utils import shuffle
+
+class BucketNerIter(DataIter):
+    """
+    This iterator can handle variable length feature/label arrays for MXNet RNN classifiers.
+    This iterator can ingest 2d list of sentences, 2d list of entities and 3d list of characters.
+
+    Every example is placed in the smallest bucket whose length is at least the example's
+    sequence length and is padded up to that length. Examples longer than the largest bucket
+    are sliced down to the largest bucket size. Only full batches are produced: examples left
+    over in a bucket after the last full batch are not served.
+
+    :param sentences: list of sentences, each a list of token indices
+    :param characters: list of sentences, each a list of tokens, each a list of character indices
+    :param label: list of label sequences (assumed to have the same lengths as the matching sentences)
+    :param max_token_chars: every token is sliced/padded to exactly this many characters
+    :param batch_size: number of examples per batch
+    :param buckets: sequence lengths to bucket by; if empty or None, one bucket is created for
+        every sequence length that occurs at least `batch_size` times
+    :param data_pad: value used to pad the token and character arrays
+    :param label_pad: value used to pad the label arrays
+    :param data_names: names of the token input and the character input, in that order
+    :param label_name: name of the label input
+    :param dtype: dtype of all arrays produced by the iterator
+    """
+
+    def __init__(self, sentences, characters, label, max_token_chars, batch_size, buckets=None, data_pad=-1, label_pad = -1, data_names=('sentences', 'characters'),
+                 label_name='seq_label', dtype='float32'):
+
+        super(BucketNerIter, self).__init__()
+
+        # Create a bucket for every seq length where there are at least as many examples as the batch size
+        if not buckets:
+            seq_counts = np.bincount([len(s) for s in sentences])
+            buckets = [i for i, j in enumerate(seq_counts) if j >= batch_size]
+        # Sort a copy so that a list passed in by the caller is never mutated
+        buckets = sorted(buckets)
+        print("\nBuckets  created: ", buckets)
+        assert(len(buckets) > 0), "Not enough utterances to create any buckets."
+
+        ###########
+        # Sentences
+        ###########
+        nslice = 0
+        # Create empty nested lists for storing data that falls into each bucket
+        self.sentences = [[] for _ in buckets]
+        for sent in sentences:
+            # Find the index of the smallest bucket that is at least as large as the sentence length
+            buck_idx = bisect.bisect_left(buckets, len(sent))
+
+            if buck_idx == len(buckets): # If the sentence is larger than the largest bucket
+                buck_idx = buck_idx - 1
+                nslice += 1
+                sent = sent[:buckets[buck_idx]] # Slice sentence to largest bucket size
+
+            buff = np.full((buckets[buck_idx]), data_pad, dtype=dtype) # Create an array filled with 'data_pad'
+            buff[:len(sent)] = sent # Fill with actual values
+            self.sentences[buck_idx].append(buff) # Append array to index = bucket index
+        self.sentences = [np.asarray(i, dtype=dtype) for i in self.sentences] # Convert to list of array
+        if nslice > 0:
+            print("Warning, {0} sentences sliced to largest bucket size.".format(nslice))
+
+        ############
+        # Characters
+        ############
+        # Create empty nested lists for storing data that falls into each bucket
+        self.characters = [[] for _ in buckets]
+        for charsent in characters:
+            # Find the index of the smallest bucket that is at least as large as the sentence length
+            buck_idx = bisect.bisect_left(buckets, len(charsent))
+
+            if buck_idx == len(buckets): # If the sentence is larger than the largest bucket
+                buck_idx = buck_idx - 1
+                charsent = charsent[:buckets[buck_idx]] # Slice sentence to largest bucket size
+
+            charsent = [word[:max_token_chars] for word in charsent] # Slice to max length
+            charsent = [word + [data_pad]*(max_token_chars-len(word)) for word in charsent] # Pad to max length
+            charsent = np.array(charsent)
+            buff = np.full((buckets[buck_idx], max_token_chars), data_pad, dtype=dtype)
+            buff[:charsent.shape[0], :] = charsent # Fill with actual values
+            self.characters[buck_idx].append(buff) # Append array to index = bucket index
+        self.characters = [np.asarray(i, dtype=dtype) for i in self.characters] # Convert to list of array
+
+        ##########
+        # Entities
+        ##########
+        # Create empty nested lists for storing data that falls into each bucket
+        self.label = [[] for _ in buckets]
+        self.indices = [[] for _ in buckets]
+        for i, entities in enumerate(label):
+            # Find the index of the smallest bucket that is at least as large as the sentence length
+            buck_idx = bisect.bisect_left(buckets, len(entities))
+
+            if buck_idx == len(buckets):  # If the sentence is larger than the largest bucket
+                buck_idx = buck_idx - 1
+                entities = entities[:buckets[buck_idx]]  # Slice sentence to largest bucket size
+
+            buff = np.full((buckets[buck_idx]), label_pad, dtype=dtype)  # Create an array filled with 'label_pad'
+            buff[:len(entities)] = entities  # Fill with actual values
+            self.label[buck_idx].append(buff)  # Append array to index = bucket index
+            self.indices[buck_idx].append(i)  # Remember the position of the example in the input
+        self.label = [np.asarray(i, dtype=dtype) for i in self.label]  # Convert to list of array
+        self.indices = [np.asarray(i, dtype=dtype) for i in self.indices]  # Convert to list of array
+
+        self.data_names = list(data_names)
+        self.label_name = label_name
+        self.batch_size = batch_size
+        self.max_token_chars = max_token_chars
+        self.buckets = buckets
+        self.dtype = dtype
+        self.data_pad = data_pad
+        self.label_pad = label_pad
+        self.default_bucket_key = max(buckets)
+        self.layout = 'NT'
+
+        self.provide_data = [DataDesc(name=self.data_names[0], shape=(self.batch_size, self.default_bucket_key), layout=self.layout),
+                             DataDesc(name=self.data_names[1], shape=(self.batch_size, self.default_bucket_key, self.max_token_chars), layout=self.layout)]
+        self.provide_label=[DataDesc(name=self.label_name, shape=(self.batch_size, self.default_bucket_key), layout=self.layout)]
+
+        # Build the list of (bucket index, starting record) pairs, one per full batch,
+        # eg with batch size 5 and 20 training examples in each of two buckets:
+        # [(0,0), (0,5), (0,10), (0,15), (1,0), (1,5), (1,10), (1,15)]
+        self.idx = []
+        for i, buck in enumerate(self.sentences):
+            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
+        self.curr_idx = 0
+        self.reset()
+
+    def reset(self):
+        """Resets the iterator to the beginning of the data."""
+        self.curr_idx = 0
+        # shuffle the order in which batches are served
+        random.shuffle(self.idx)
+        # shuffle data in each bucket, applying the same permutation to all four arrays
+        for i, buck in enumerate(self.sentences):
+            self.indices[i], self.sentences[i], self.characters[i], self.label[i] = shuffle(self.indices[i],
+                                                                                            self.sentences[i],
+                                                                                            self.characters[i],
+                                                                                            self.label[i])
+
+        self.ndindex = []
+        self.ndsent = []
+        self.ndchar = []
+        self.ndlabel = []
+
+        # for each bucket of data, store the shuffled arrays as MXNet ndarrays
+        for i, buck in enumerate(self.sentences):
+            self.ndindex.append(ndarray.array(self.indices[i], dtype=self.dtype))
+            self.ndsent.append(ndarray.array(self.sentences[i], dtype=self.dtype))
+            self.ndchar.append(ndarray.array(self.characters[i], dtype=self.dtype))
+            self.ndlabel.append(ndarray.array(self.label[i], dtype=self.dtype))
+
+    def next(self):
+        """Returns the next batch of data."""
+        if self.curr_idx == len(self.idx):
+            raise StopIteration
+        # i = bucket index, j = starting record within that bucket
+        i, j = self.idx[self.curr_idx]
+        self.curr_idx += 1
+
+        indices = self.ndindex[i][j:j + self.batch_size]
+        sentences = self.ndsent[i][j:j + self.batch_size]
+        characters = self.ndchar[i][j:j + self.batch_size]
+        label = self.ndlabel[i][j:j + self.batch_size]
+
+        return DataBatch([sentences, characters], [label], pad=0, index = indices, bucket_key=self.buckets[i],
+                         provide_data=[DataDesc(name=self.data_names[0], shape=sentences.shape, layout=self.layout),
+                                       DataDesc(name=self.data_names[1], shape=characters.shape, layout=self.layout)],
+                         provide_label=[DataDesc(name=self.label_name, shape=label.shape, layout=self.layout)])
\ No newline at end of file
diff --git a/example/named_entity_recognition/src/metrics.py b/example/named_entity_recognition/src/metrics.py
new file mode 100644
index 0000000..40c5015
--- /dev/null
+++ b/example/named_entity_recognition/src/metrics.py
@@ -0,0 +1,79 @@
+# !/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+import mxnet as mx
+import numpy as np
+import pickle
+
+def load_obj(name):
+    """Unpickle and return the object stored in the file `name` + '.pkl'."""
+    path = name + '.pkl'
+    with open(path, 'rb') as handle:
+        obj = pickle.load(handle)
+    return obj
+
+# Loaded at import time: importing this module fails unless the tag_to_index pickle
+# (written by build_iters in ner.py) already exists. The path is fixed here and does
+# not follow ner.py's --data-dir option.
+tag_dict = load_obj("../preprocessed_data/tag_to_index")
+# index of the "O" tag, which the metrics below treat as "not an entity"
+not_entity_index = tag_dict["O"]
+
+def classifer_metrics(label, pred, ignore_label=-1):
+    """
+    computes f1, precision and recall on the entity class
+
+    Positions whose label equals `ignore_label` are padding and are excluded from every
+    count; otherwise they would be counted as entities (padding is never equal to the
+    "O" index) and would deflate both precision and recall.
+
+    :param label: array of true tag indices
+    :param pred: array of per-class scores with the class axis at position 1
+    :param ignore_label: label value that marks padded positions
+    :return: precision, recall, f1; a value is np.nan when it is undefined
+    """
+    prediction = np.argmax(pred, axis=1)
+    label = label.astype(int)
+
+    # only real (non-padded) positions take part in the metrics
+    is_token = label != ignore_label
+    pred_is_entity = (prediction != not_entity_index) & is_token
+    label_is_entity = (label != not_entity_index) & is_token
+
+    #how many entities are there?
+    num_entities = np.sum(label_is_entity)
+    entity_preds = np.sum(pred_is_entity)
+
+    #how many times did we correctly predict an entity?
+    correct_entities = np.sum((prediction == label) & pred_is_entity)
+
+    #precision: when we predict entity, how often are we right?
+    # float() forces true division on Python 2 as well; the guard avoids dividing by zero
+    precision = correct_entities / float(entity_preds) if entity_preds > 0 else np.nan
+
+    #recall: of the things that were an entity, how many did we catch?
+    recall = correct_entities / float(num_entities) if num_entities > 0 else np.nan
+
+    # f1 is undefined when either input is undefined or when both are zero
+    if np.isnan(precision) or np.isnan(recall) or precision + recall == 0:
+        f1 = np.nan
+    else:
+        f1 = 2 * precision * recall / (precision + recall)
+    return precision, recall, f1
+
+def entity_precision(label, pred):
+    """Return only the precision component computed by classifer_metrics."""
+    precision, _, _ = classifer_metrics(label, pred)
+    return precision
+
+def entity_recall(label, pred):
+    """Return only the recall component computed by classifer_metrics."""
+    _, recall, _ = classifer_metrics(label, pred)
+    return recall
+
+def entity_f1(label, pred):
+    """Return only the f1 component computed by classifer_metrics."""
+    _, _, f1 = classifer_metrics(label, pred)
+    return f1
+
+def composite_classifier_metrics():
+    """
+    Build the evaluation metric used for training: accuracy followed by the
+    entity-level precision, recall and f1 score.
+    """
+    entity_metrics = [mx.metric.CustomMetric(feval=feval, name=name)
+                      for feval, name in ((entity_precision, 'entity precision'),
+                                          (entity_recall, 'entity recall'),
+                                          (entity_f1, 'entity f1 score'))]
+
+    return mx.metric.CompositeEvalMetric([mx.metric.Accuracy()] + entity_metrics)
diff --git a/example/named_entity_recognition/src/ner.py b/example/named_entity_recognition/src/ner.py
new file mode 100644
index 0000000..561db4c
--- /dev/null
+++ b/example/named_entity_recognition/src/ner.py
@@ -0,0 +1,236 @@
+# !/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+from collections import Counter
+import itertools
+import iterators
+import os
+import numpy as np
+import pandas as pd
+import mxnet as mx
+import argparse
+import pickle
+import logging
+
+logging.basicConfig(level=logging.DEBUG)
+
+# Command line options. Option names and defaults are unchanged; only the description and
+# help texts that were left over from an unrelated example have been corrected.
+parser = argparse.ArgumentParser(description="Deep neural network for named entity recognition",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--data-dir', type=str, default='../preprocessed_data',
+                    help='relative path to input data')
+parser.add_argument('--output-dir', type=str, default='../results',
+                    help='directory to save model files to')
+parser.add_argument('--max-records', type=int, default=None,
+                    help='total records before data split')
+parser.add_argument('--train_fraction', type=float, default=0.8,
+                    help='fraction of data to use for training. remainder used for testing.')
+parser.add_argument('--batch-size', type=int, default=128,
+                    help='the batch size.')
+parser.add_argument('--buckets', type=str, default="",
+                    help='unique bucket sizes')
+parser.add_argument('--char-embed', type=int, default=25,
+                    help='Embedding size for each unique character.')
+parser.add_argument('--char-filter-list', type=str, default="3,4,5",
+                    help='unique filter sizes for char level cnn')
+parser.add_argument('--char-filters', type=int, default=20,
+                    help='number of each filter size')
+parser.add_argument('--word-embed', type=int, default=500,
+                    help='Embedding size for each unique word.')
+parser.add_argument('--word-filter-list', type=str, default="3,4,5",
+                    help='unique filter sizes for word level cnn')
+parser.add_argument('--word-filters', type=int, default=200,
+                    help='number of each filter size')
+parser.add_argument('--lstm-state-size', type=int, default=100,
+                    help='number of hidden units in each unrolled recurrent cell')
+parser.add_argument('--lstm-layers', type=int, default=1,
+                    help='number of recurrent layers')
+parser.add_argument('--gpus', type=str, default='',
+                    help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. ')
+parser.add_argument('--optimizer', type=str, default='adam',
+                    help='the optimizer type')
+parser.add_argument('--lr', type=float, default=0.001,
+                    help='initial learning rate')
+parser.add_argument('--dropout', type=float, default=0.2,
+                    help='dropout rate for network')
+parser.add_argument('--num-epochs', type=int, default=100,
+                    help='max num of epochs')
+parser.add_argument('--save-period', type=int, default=20,
+                    help='save checkpoint for every n epochs')
+parser.add_argument('--model_prefix', type=str, default='electricity_model',
+                    help='prefix for saving model params')
+
+def save_obj(obj, name):
+    """Pickle `obj` into the file `name` + '.pkl' with the highest available protocol."""
+    target = name + '.pkl'
+    with open(target, 'wb') as handle:
+        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+def save_model():
+    """
+    Make sure the output directory exists and build the checkpoint callback.
+    :return: the callback returned by mx.callback.do_checkpoint for the prefix
+             <output dir>/checkpoint and the period args.save_period
+    """
+    output_dir = args.output_dir
+    if not os.path.exists(output_dir):
+        os.mkdir(output_dir)
+    checkpoint_prefix = os.path.join(output_dir, "checkpoint")
+    return mx.callback.do_checkpoint(checkpoint_prefix, args.save_period)
+
+def build_vocab(nested_list):
+    """
+    Index every distinct item of a nested list by how often it occurs.
+    :param nested_list: list of list of string
+    :return: dictionary mapping from string to int (most frequent string gets 0),
+             list of strings where position i holds the string mapped to i
+    """
+    # Count occurrences over the flattened input
+    word_counts = Counter(itertools.chain.from_iterable(nested_list))
+
+    # Index -> label, ordered from most to least frequent
+    vocabulary_inv = [token for token, _ in word_counts.most_common()]
+
+    # Label -> index
+    vocabulary = dict(zip(vocabulary_inv, range(len(vocabulary_inv))))
+    return vocabulary, vocabulary_inv
+
+def build_iters(data_dir, max_records, train_fraction, batch_size, buckets=None):
+    """
+    Reads the pickled dataframe of sentences/tag sequences.
+    Converts into X = array(list(int)) & Y = array(list(int))
+    Splits into training and test sets
+    Builds dictionaries mapping from labels/features to integer indices and saves the
+    tag dictionary to <data_dir>/tag_to_index.pkl
+    :param data_dir: directory to read ner_data.pkl from and to write tag_to_index.pkl to
+    :param max_records: only the first max_records rows of the input data are used (None uses all)
+    :param train_fraction: fraction of the data to use for training
+    :param batch_size: records in mini-batches during training
+    :param buckets: size of each bucket in the iterators
+    :return: train_iter, val_iter, word_to_index, char_to_index, entity_to_index
+    """
+    # Read in data as a pandas dataframe
+    df = pd.read_pickle(os.path.join(data_dir, "ner_data.pkl"))[:max_records]
+
+    # Get feature lists
+    entities=[list(array) for array in df["BILOU_tag"].values]
+    sentences = [list(array) for array in df["token"].values]
+    chars=[[[c for c in word] for word in sentence] for sentence in sentences]
+
+    # Build vocabularies
+    entity_to_index, index_to_entity = build_vocab(entities)
+    word_to_index, index_to_word = build_vocab(sentences)
+    char_to_index, index_to_char = build_vocab([np.array([c for c in word]) for word in index_to_word])
+    # Save next to the input data, using the data_dir argument rather than the global args
+    save_obj(entity_to_index, os.path.join(data_dir, "tag_to_index"))
+
+    # Map strings to integer values
+    indexed_entities=[list(map(entity_to_index.get, l)) for l in entities]
+    indexed_tokens=[list(map(word_to_index.get, l)) for l in sentences]
+    indexed_chars=[[list(map(char_to_index.get, word)) for word in sentence] for sentence in chars]
+
+    # Split into training and testing data
+    idx=int(len(indexed_tokens)*train_fraction)
+    X_token_train, X_char_train, Y_train = indexed_tokens[:idx], indexed_chars[:idx], indexed_entities[:idx]
+    X_token_test, X_char_test, Y_test = indexed_tokens[idx:], indexed_chars[idx:], indexed_entities[idx:]
+
+    # build iterators to feed batches to network
+    # the validation iterator reuses the training iterator's buckets and token length
+    train_iter = iterators.BucketNerIter(sentences=X_token_train, characters=X_char_train, label=Y_train,
+                                         max_token_chars=5, batch_size=batch_size, buckets=buckets)
+    val_iter = iterators.BucketNerIter(sentences=X_token_test, characters=X_char_test, label=Y_test,
+                                         max_token_chars=train_iter.max_token_chars, batch_size=batch_size, buckets=train_iter.buckets)
+    return train_iter, val_iter, word_to_index, char_to_index, entity_to_index
+
+def sym_gen(seq_len):
+    """
+    Build NN symbol depending on the length of the input sequence
+
+    Reads the module-level names train_iter, args, char_to_index, word_to_index,
+    entity_to_index and bi_cell, which are set in the __main__ block.
+    :param seq_len: bucket key, i.e. the padded sequence length to build the symbol for
+    :return: the softmax output symbol, the list of data names, the list of label names
+    """
+    # Shapes for the default (largest) bucket; only char_sentence_shape is used below
+    sentence_shape = train_iter.provide_data[0][1]
+    char_sentence_shape = train_iter.provide_data[1][1]
+    entities_shape = train_iter.provide_label[0][1]
+
+    X_sent = mx.symbol.Variable(train_iter.provide_data[0].name)
+    X_char_sent = mx.symbol.Variable(train_iter.provide_data[1].name)
+    Y = mx.sym.Variable(train_iter.provide_label[0].name)
+
+    ###############################
+    # Character embedding component
+    ###############################
+    char_embeddings = mx.sym.Embedding(data=X_char_sent, input_dim=len(char_to_index), output_dim=args.char_embed, name='char_embed')
+    # insert a channel axis of size 1 after the batch axis so the 3d convolution can be applied
+    char_embeddings = mx.sym.reshape(data=char_embeddings, shape=(0,1,seq_len,-1,args.char_embed), name='char_embed2')
+
+    char_cnn_outputs = []
+    for i, filter_size in enumerate(args.char_filter_list):
+        # Kernel that slides over entire words resulting in a 1d output
+        convi = mx.sym.Convolution(data=char_embeddings, kernel=(1, filter_size, args.char_embed), stride=(1, 1, 1),
+                                   num_filter=args.char_filters, name="char_conv_layer_" + str(i))
+        acti = mx.sym.Activation(data=convi, act_type='tanh')
+        # max pool over all filter positions within a token (token length minus filter size plus one)
+        pooli = mx.sym.Pooling(data=acti, pool_type='max', kernel=(1, char_sentence_shape[2] - filter_size + 1, 1),
+                               stride=(1, 1, 1), name="char_pool_layer_" + str(i))
+        # keep the first three axes and swap the last two of them
+        # (expected result: batch, seq_len, num filters - TODO confirm)
+        pooli = mx.sym.transpose(mx.sym.Reshape(pooli, shape=(0, 0, 0)), axes=(0, 2, 1), name="cchar_conv_layer_" + str(i))
+        char_cnn_outputs.append(pooli)
+
+    # combine features from all filters & apply dropout
+    cnn_char_features = mx.sym.Concat(*char_cnn_outputs, dim=2, name="cnn_char_features")
+    regularized_cnn_char_features = mx.sym.Dropout(data=cnn_char_features, p=args.dropout, mode='training',
+                                                   name='regularized charCnn features')
+
+    ##################################
+    # Combine char and word embeddings
+    ##################################
+    word_embeddings = mx.sym.Embedding(data=X_sent, input_dim=len(word_to_index), output_dim=args.word_embed, name='word_embed')
+    rnn_features = mx.sym.Concat(*[word_embeddings, regularized_cnn_char_features], dim=2, name='rnn input')
+
+    ##############################
+    # Bidirectional LSTM component
+    ##############################
+
+    # unroll the lstm cell in time, merging outputs
+    # the same module-level cell is reused for every bucket, so it is reset before each unroll
+    bi_cell.reset()
+    output, states = bi_cell.unroll(length=seq_len, inputs=rnn_features, merge_outputs=True)
+
+    # Map to num entity classes
+    # forward and backward states are concatenated, hence twice the lstm state size
+    rnn_output = mx.sym.Reshape(output, shape=(-1, args.lstm_state_size * 2), name='r_output')
+    fc = mx.sym.FullyConnected(data=rnn_output, num_hidden=len(entity_to_index), name='fc_layer')
+
+    # reshape back to same shape as loss will be
+    reshaped_fc = mx.sym.transpose(mx.sym.reshape(fc, shape=(-1, seq_len, len(entity_to_index))), axes=(0, 2, 1))
+    # labels equal to -1 (the iterator's default label padding) are ignored by the loss
+    sm = mx.sym.SoftmaxOutput(data=reshaped_fc, label=Y, ignore_label=-1, use_ignore=True, multi_output=True, name='softmax')
+    return sm, [v.name for v in train_iter.provide_data], [v.name for v in train_iter.provide_label]
+
+def train(train_iter, val_iter):
+    """
+    Fit a bucketing module built from sym_gen on the training iterator.
+    :param train_iter: iterator providing the training batches
+    :param val_iter: iterator providing the evaluation batches
+    """
+    # imported here rather than at the top of the file: metrics loads the tag_to_index
+    # pickle at import time, and that file is only written by build_iters
+    import metrics
+    # an unset or empty --gpus value means train on cpu
+    # (compare by truthiness, not by identity with a string literal)
+    devs = mx.cpu() if not args.gpus else [mx.gpu(int(i)) for i in args.gpus.split(',')]
+    module = mx.mod.BucketingModule(sym_gen, train_iter.default_bucket_key, context=devs)
+    module.fit(train_data=train_iter,
+               eval_data=val_iter,
+               eval_metric=metrics.composite_classifier_metrics(),
+               optimizer=args.optimizer,
+               optimizer_params={'learning_rate': args.lr },
+               initializer=mx.initializer.Uniform(0.1),
+               num_epoch=args.num_epochs,
+               epoch_end_callback=save_model())
+
+# Script entry point. The names bound here (args, train_iter, char_to_index, word_to_index,
+# entity_to_index, bi_cell) are module-level globals that sym_gen, save_model and train read.
+if __name__ == '__main__':
+    # parse args
+    args = parser.parse_args()
+    # comma separated option strings become lists of int; an empty --buckets becomes None
+    args.buckets = list(map(int, args.buckets.split(','))) if len(args.buckets) > 0 else None
+    args.char_filter_list = list(map(int, args.char_filter_list.split(',')))
+
+    # Build data iterators
+    train_iter, val_iter, word_to_index, char_to_index, entity_to_index = build_iters(args.data_dir, args.max_records,
+                                                                     args.train_fraction, args.batch_size, args.buckets)
+
+    # Define the recurrent layer
+    # each layer is a bidirectional LSTM followed by dropout
+    bi_cell = mx.rnn.SequentialRNNCell()
+    for layer_num in range(args.lstm_layers):
+        bi_cell.add(mx.rnn.BidirectionalCell(
+            mx.rnn.LSTMCell(num_hidden=args.lstm_state_size, prefix="forward_layer_" + str(layer_num)),
+            mx.rnn.LSTMCell(num_hidden=args.lstm_state_size, prefix="backward_layer_" + str(layer_num))))
+        bi_cell.add(mx.rnn.DropoutCell(args.dropout))
+
+    train(train_iter, val_iter)
\ No newline at end of file
diff --git a/example/named_entity_recognition/src/preprocess.py b/example/named_entity_recognition/src/preprocess.py
new file mode 100644
index 0000000..6ae348a
--- /dev/null
+++ b/example/named_entity_recognition/src/preprocess.py
@@ -0,0 +1,50 @@
+# !/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+import pandas as pd
+import numpy as np
+
+#read in csv of NER training data
+df = pd.read_csv("../data/ner_dataset.csv", encoding="ISO-8859-1")
+
+#rename columns
+df = df.rename(columns = {"Sentence #" : "utterance_id",
+                            "Word" : "token",
+                            "POS" : "POS_tag",
+                            "Tag" : "BILOU_tag"})
+
+#clean utterance_id column
+df.loc[:, "utterance_id"] = df["utterance_id"].str.replace('Sentence: ', '')
+
+#fill np.nan utterance ID's with the last valid entry
+df = df.ffill()
+df.loc[:, "utterance_id"] = df["utterance_id"].apply(int)
+
+#melt BILOU tags, tokens and POS tags into an array per utterance
+df1 = df.groupby("utterance_id")["BILOU_tag"].apply(lambda x: np.array(x)).to_frame().reset_index()
+df2 = df.groupby("utterance_id")["token"].apply(lambda x: np.array(x)).to_frame().reset_index()
+df3 = df.groupby("utterance_id")["POS_tag"].apply(lambda x: np.array(x)).to_frame().reset_index()
+
+#join the results on utterance id
+df = df1.merge(df2.merge(df3, how = "left", on = "utterance_id"), how = "left", on = "utterance_id")
+
+#save the dataframe to a pickle file in the directory that ner.py reads from by default
+#(the README asks for ./preprocessed_data to be created before running this script)
+df.to_pickle("../preprocessed_data/ner_data.pkl")
\ No newline at end of file