You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/30 20:31:20 UTC
[incubator-mxnet] branch master updated: new NER example: MXNET-321
(#10514)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 12709ac new NER example: MXNET-321 (#10514)
12709ac is described below
commit 12709aca486fc60c28ea968934de95f8d3d0b0b9
Author: Oliver Pringle <oj...@gmail.com>
AuthorDate: Sat Jun 30 13:31:10 2018 -0700
new NER example: MXNET-321 (#10514)
* new NER example
* removing images
* updating readme to generate preprocessed data
* instructions to download and preprocess training data
---
example/named_entity_recognition/README.md | 19 ++
example/named_entity_recognition/src/iterators.py | 175 +++++++++++++++
example/named_entity_recognition/src/metrics.py | 79 +++++++
example/named_entity_recognition/src/ner.py | 236 +++++++++++++++++++++
example/named_entity_recognition/src/preprocess.py | 50 +++++
5 files changed, 559 insertions(+)
diff --git a/example/named_entity_recognition/README.md b/example/named_entity_recognition/README.md
new file mode 100644
index 0000000..260c19d
--- /dev/null
+++ b/example/named_entity_recognition/README.md
@@ -0,0 +1,19 @@
+## Goal
+
+- This example contains an MXNet implementation of a state-of-the-art [entity recognition model](https://www.aclweb.org/anthology/Q16-1026).
+- You can find my blog post on the model [here](https://opringle.github.io/2018/02/06/CNNLSTM_entity_recognition.html).
+
+![](https://github.com/dmlc/web-data/blob/master/mxnet/example/ner/arch1.png?raw=true)
+
+## Running the code
+
+To reproduce the preprocessed training data:
+
+1. Download and unzip the data: https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus/downloads/ner_dataset.csv
+2. Move ner_dataset.csv into `./data`
+3. Create the `./preprocessed_data` directory
+4. `$ cd src && python preprocess.py`
+
+To train the model:
+
+- `$ cd src && python ner.py`
\ No newline at end of file
diff --git a/example/named_entity_recognition/src/iterators.py b/example/named_entity_recognition/src/iterators.py
new file mode 100644
index 0000000..a11c570
--- /dev/null
+++ b/example/named_entity_recognition/src/iterators.py
@@ -0,0 +1,175 @@
+# !/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+import bisect
+import random
+import numpy as np
+from mxnet.io import DataIter, DataBatch, DataDesc
+from mxnet import ndarray
+from sklearn.utils import shuffle
+
class BucketNerIter(DataIter):
    """
    Bucketing iterator for variable length sentence / character / entity sequences.

    Ingests a 2d list of token indices (``sentences``), a 3d list of character indices
    (``characters``: sentence -> token -> characters) and a 2d list of entity indices
    (``label``), all aligned by position. Every record is assigned to the smallest bucket
    whose size is >= its length and padded up to that size; records longer than the
    largest bucket are truncated to it. Only full batches are produced.

    :param sentences: list of token index sequences
    :param characters: list of per-token character index sequences, aligned with ``sentences``
    :param label: list of entity index sequences, aligned with ``sentences``
    :param max_token_chars: every token is truncated/padded to this many characters
    :param batch_size: number of records per mini-batch
    :param buckets: bucket sizes; if falsy, one bucket is created for every sentence
        length that occurs at least ``batch_size`` times
    :param data_pad: padding value for sentences and characters
    :param label_pad: padding value for labels
    :param data_names: names of the (sentence, character) inputs;
        defaults to ['sentences', 'characters']
    :param label_name: name of the label input
    :param dtype: dtype of all stored arrays
    """

    def __init__(self, sentences, characters, label, max_token_chars, batch_size, buckets=None, data_pad=-1,
                 label_pad=-1, data_names=None, label_name='seq_label', dtype='float32'):

        super(BucketNerIter, self).__init__()

        # Built per instance so iterators never share one mutable default list.
        if data_names is None:
            data_names = ['sentences', 'characters']

        # Create a bucket for every seq length where there are at least batch_size examples
        if not buckets:
            seq_counts = np.bincount([len(s) for s in sentences])
            buckets = [i for i, j in enumerate(seq_counts) if j >= batch_size]
        # bisect needs ascending buckets; sorted() avoids mutating a caller-supplied list in place.
        buckets = sorted(buckets)
        print("\nBuckets created: ", buckets)
        assert(len(buckets) > 0), "Not enough utterances to create any buckets."

        ###########
        # Sentences
        ###########
        nslice = 0
        self.sentences = [[] for _ in buckets]
        for sent in sentences:
            buck_idx, too_long = self._bucket_index(buckets, len(sent))
            if too_long:
                nslice += 1
                sent = sent[:buckets[buck_idx]]  # Slice sentence to largest bucket size
            buff = np.full((buckets[buck_idx],), data_pad, dtype=dtype)  # Array filled with 'data_pad'
            buff[:len(sent)] = sent  # Fill with actual values
            self.sentences[buck_idx].append(buff)
        self.sentences = [self._stack(rows, (size,), dtype) for rows, size in zip(self.sentences, buckets)]
        if nslice > 0:
            print("Warning, {0} sentences sliced to largest bucket size.".format(nslice))

        ############
        # Characters
        ############
        self.characters = [[] for _ in buckets]
        for charsent in characters:
            buck_idx, too_long = self._bucket_index(buckets, len(charsent))
            if too_long:
                charsent = charsent[:buckets[buck_idx]]  # Slice sentence to largest bucket size
            # One row per token: its first max_token_chars characters, the rest left as data_pad.
            # Filling row by row also works for an empty sentence (no broadcast of an empty array).
            buff = np.full((buckets[buck_idx], max_token_chars), data_pad, dtype=dtype)
            for row, word in enumerate(charsent):
                word = word[:max_token_chars]
                buff[row, :len(word)] = word
            self.characters[buck_idx].append(buff)
        self.characters = [self._stack(rows, (size, max_token_chars), dtype)
                           for rows, size in zip(self.characters, buckets)]

        ##########
        # Entities
        ##########
        self.label = [[] for _ in buckets]
        self.indices = [[] for _ in buckets]
        for i, entities in enumerate(label):
            buck_idx, too_long = self._bucket_index(buckets, len(entities))
            if too_long:
                entities = entities[:buckets[buck_idx]]  # Slice sentence to largest bucket size
            buff = np.full((buckets[buck_idx],), label_pad, dtype=dtype)  # Array filled with 'label_pad'
            buff[:len(entities)] = entities  # Fill with actual values
            self.label[buck_idx].append(buff)
            self.indices[buck_idx].append(i)  # Position of the record in the input lists
        self.label = [self._stack(rows, (size,), dtype) for rows, size in zip(self.label, buckets)]
        self.indices = [self._stack(rows, (), dtype) for rows in self.indices]

        self.data_names = data_names
        self.label_name = label_name
        self.batch_size = batch_size
        self.max_token_chars = max_token_chars
        self.buckets = buckets
        self.dtype = dtype
        self.data_pad = data_pad
        self.label_pad = label_pad
        self.default_bucket_key = max(buckets)
        self.layout = 'NT'

        self.provide_data = [DataDesc(name=self.data_names[0], shape=(self.batch_size, self.default_bucket_key), layout=self.layout),
                             DataDesc(name=self.data_names[1], shape=(self.batch_size, self.default_bucket_key, self.max_token_chars), layout=self.layout)]
        self.provide_label = [DataDesc(name=self.label_name, shape=(self.batch_size, self.default_bucket_key), layout=self.layout)]

        # (bucket index, start row) for every full batch, e.g. batch size 5 and 20 records in
        # bucket 0 gives [(0, 0), (0, 5), (0, 10), (0, 15)]. Trailing partial batches are dropped.
        self.idx = []
        for i, buck in enumerate(self.sentences):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.reset()

    @staticmethod
    def _bucket_index(buckets, length):
        """Return (index of the smallest bucket >= length, whether the record must be truncated)."""
        buck_idx = bisect.bisect_left(buckets, length)
        if buck_idx == len(buckets):
            # Longer than the largest bucket: use the last bucket and truncate.
            return buck_idx - 1, True
        return buck_idx, False

    @staticmethod
    def _stack(rows, row_shape, dtype):
        """Stack padded rows into one array, keeping the row shape even when a bucket is empty."""
        if not rows:
            return np.empty((0,) + tuple(row_shape), dtype=dtype)
        return np.asarray(rows, dtype=dtype)

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        self.curr_idx = 0
        # Shuffle batch order, then records within each bucket (features, labels and indices together)
        random.shuffle(self.idx)
        for i, _ in enumerate(self.sentences):
            self.indices[i], self.sentences[i], self.characters[i], self.label[i] = shuffle(self.indices[i],
                                                                                              self.sentences[i],
                                                                                              self.characters[i],
                                                                                              self.label[i])

        self.ndindex = []
        self.ndsent = []
        self.ndchar = []
        self.ndlabel = []

        # One NDArray per bucket for each field
        for i, _ in enumerate(self.sentences):
            self.ndindex.append(ndarray.array(self.indices[i], dtype=self.dtype))
            self.ndsent.append(ndarray.array(self.sentences[i], dtype=self.dtype))
            self.ndchar.append(ndarray.array(self.characters[i], dtype=self.dtype))
            self.ndlabel.append(ndarray.array(self.label[i], dtype=self.dtype))

    def next(self):
        """Returns the next batch of data."""
        if self.curr_idx == len(self.idx):
            raise StopIteration
        # i = bucket index, j = starting record within that bucket
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        indices = self.ndindex[i][j:j + self.batch_size]
        sentences = self.ndsent[i][j:j + self.batch_size]
        characters = self.ndchar[i][j:j + self.batch_size]
        label = self.ndlabel[i][j:j + self.batch_size]

        return DataBatch([sentences, characters], [label], pad=0, index=indices, bucket_key=self.buckets[i],
                         provide_data=[DataDesc(name=self.data_names[0], shape=sentences.shape, layout=self.layout),
                                       DataDesc(name=self.data_names[1], shape=characters.shape, layout=self.layout)],
                         provide_label=[DataDesc(name=self.label_name, shape=label.shape, layout=self.layout)])
\ No newline at end of file
diff --git a/example/named_entity_recognition/src/metrics.py b/example/named_entity_recognition/src/metrics.py
new file mode 100644
index 0000000..40c5015
--- /dev/null
+++ b/example/named_entity_recognition/src/metrics.py
@@ -0,0 +1,79 @@
+# !/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+import mxnet as mx
+import numpy as np
+import pickle
+
def load_obj(name):
    """Unpickle and return the object stored in ``<name>.pkl``.

    Only use on trusted local files: unpickling can execute arbitrary code.
    """
    path = name + '.pkl'
    with open(path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj
+
# Loaded once at import time: the tag -> index mapping written by ner.py's build_iters.
# NOTE(review): this path is fixed relative to the working directory and does not follow
# ner.py's --data-dir option; importing this module fails until that file exists.
tag_dict = load_obj("../preprocessed_data/tag_to_index")
# Index of the "O" tag, i.e. a token that is not part of any entity.
not_entity_index = tag_dict["O"]
+
def classifer_metrics(label, pred, outside_index=None, ignore_label=-1):
    """
    Compute entity-level precision, recall and f1 for one batch.

    A token counts as an entity when its tag differs from the "O" (not-an-entity) tag.
    Positions whose label equals ``ignore_label`` (the padding label the iterator uses,
    and which SoftmaxOutput also ignores) are excluded from every count.

    :param label: array of true tag indices, e.g. (batch, seq_len)
    :param pred: array of class scores with the class axis at axis 1, e.g. (batch, num_tags, seq_len)
    :param outside_index: tag index of "O"; defaults to the module-level ``not_entity_index``
    :param ignore_label: label value marking padded positions that are skipped
    :return: (precision, recall, f1). Precision/recall are nan when undefined (nothing
        predicted / no true entities); f1 is nan if either is nan and 0.0 if both are 0.
    """
    if outside_index is None:
        outside_index = not_entity_index
    prediction = np.argmax(pred, axis=1)
    label = label.astype(int)

    # Padding must not count as a true entity nor as a predicted one.
    valid = label != ignore_label
    pred_is_entity = (prediction != outside_index) & valid
    label_is_entity = (label != outside_index) & valid

    # how many entities are there / how many did we predict?
    num_entities = np.sum(label_is_entity)
    entity_preds = np.sum(pred_is_entity)

    # how many times did we correctly predict an entity?
    correct_entities = np.sum((prediction == label) & pred_is_entity)

    # float() avoids integer division (Python 2) and the zero checks avoid 0/0 warnings.
    # precision: when we predict entity, how often are we right?
    precision = correct_entities / float(entity_preds) if entity_preds > 0 else np.nan
    # recall: of the things that were an entity, how many did we catch?
    recall = correct_entities / float(num_entities) if num_entities > 0 else np.nan

    if np.isnan(precision) or np.isnan(recall):
        f1 = np.nan
    elif precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
+
def entity_precision(label, pred):
    """Entity precision for one batch; signature matches mx.metric.CustomMetric's feval."""
    precision, _, _ = classifer_metrics(label, pred)
    return precision

def entity_recall(label, pred):
    """Entity recall for one batch; signature matches mx.metric.CustomMetric's feval."""
    _, recall, _ = classifer_metrics(label, pred)
    return recall

def entity_f1(label, pred):
    """Entity f1 score for one batch; signature matches mx.metric.CustomMetric's feval."""
    _, _, f1 = classifer_metrics(label, pred)
    return f1
+
def composite_classifier_metrics():
    """Return a CompositeEvalMetric: token accuracy, then entity precision, recall and f1."""
    named_fevals = (
        (entity_precision, 'entity precision'),
        (entity_recall, 'entity recall'),
        (entity_f1, 'entity f1 score'),
    )
    entity_metrics = [mx.metric.CustomMetric(feval=fn, name=metric_name) for fn, metric_name in named_fevals]
    return mx.metric.CompositeEvalMetric([mx.metric.Accuracy()] + entity_metrics)
diff --git a/example/named_entity_recognition/src/ner.py b/example/named_entity_recognition/src/ner.py
new file mode 100644
index 0000000..561db4c
--- /dev/null
+++ b/example/named_entity_recognition/src/ner.py
@@ -0,0 +1,236 @@
+# !/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+from collections import Counter
+import itertools
+import iterators
+import os
+import numpy as np
+import pandas as pd
+import mxnet as mx
+import argparse
+import pickle
+import logging
+
logging.basicConfig(level=logging.DEBUG)

# Command line options. The comma separated options (--buckets, --char-filter-list) are
# converted to lists of ints in __main__.
parser = argparse.ArgumentParser(description="Bidirectional LSTM-CNN model for named entity recognition",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data-dir', type=str, default='../preprocessed_data',
                    help='relative path to input data')
parser.add_argument('--output-dir', type=str, default='../results',
                    help='directory to save model files to')
parser.add_argument('--max-records', type=int, default=None,
                    help='total records before data split')
# '--train_fraction' is kept as an alias of the original flag name for backward compatibility.
parser.add_argument('--train-fraction', '--train_fraction', dest='train_fraction', type=float, default=0.8,
                    help='fraction of data to use for training. remainder used for testing.')
parser.add_argument('--batch-size', type=int, default=128,
                    help='the batch size.')
parser.add_argument('--buckets', type=str, default="",
                    help='comma separated unique bucket sizes; empty means buckets are created automatically')
parser.add_argument('--char-embed', type=int, default=25,
                    help='Embedding size for each unique character.')
parser.add_argument('--char-filter-list', type=str, default="3,4,5",
                    help='unique filter sizes for char level cnn')
parser.add_argument('--char-filters', type=int, default=20,
                    help='number of each filter size')
parser.add_argument('--word-embed', type=int, default=500,
                    help='Embedding size for each unique word.')
parser.add_argument('--word-filter-list', type=str, default="3,4,5",
                    help='unique filter sizes for word level cnn (currently unused)')
parser.add_argument('--word-filters', type=int, default=200,
                    help='number of each word level filter size (currently unused)')
parser.add_argument('--lstm-state-size', type=int, default=100,
                    help='number of hidden units in each unrolled recurrent cell')
parser.add_argument('--lstm-layers', type=int, default=1,
                    help='number of recurrent layers')
parser.add_argument('--gpus', type=str, default='',
                    help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. ')
parser.add_argument('--optimizer', type=str, default='adam',
                    help='the optimizer type')
parser.add_argument('--lr', type=float, default=0.001,
                    help='initial learning rate')
parser.add_argument('--dropout', type=float, default=0.2,
                    help='dropout rate for network')
parser.add_argument('--num-epochs', type=int, default=100,
                    help='max num of epochs')
parser.add_argument('--save-period', type=int, default=20,
                    help='save checkpoint for every n epochs')
# '--model_prefix' is kept as an alias of the original flag name for backward compatibility.
parser.add_argument('--model-prefix', '--model_prefix', dest='model_prefix', type=str, default='ner_model',
                    help='prefix for saving model params (currently unused)')
+
def save_obj(obj, name):
    """Pickle ``obj`` into ``<name>.pkl`` with the highest protocol available."""
    target = name + '.pkl'
    with open(target, 'wb') as out_file:
        pickle.dump(obj, out_file, pickle.HIGHEST_PROTOCOL)
+
def save_model(output_dir=None, save_period=None):
    """
    Create the checkpoint directory if needed and return an epoch-end checkpoint callback.

    :param output_dir: directory for checkpoint files; defaults to the global ``args.output_dir``
    :param save_period: checkpoint every this many epochs; defaults to the global ``args.save_period``
    :return: the callback from ``mx.callback.do_checkpoint`` with prefix ``<output_dir>/checkpoint``
    """
    if output_dir is None:
        output_dir = args.output_dir
    if save_period is None:
        save_period = args.save_period
    # makedirs (unlike mkdir) also creates missing parent directories.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    return mx.callback.do_checkpoint(os.path.join(output_dir, "checkpoint"), save_period)
+
def build_vocab(nested_list):
    """
    Build a frequency-ordered vocabulary from a nested list of items.

    :param nested_list: list of list of string
    :return: (dict mapping each item to its integer index, list mapping index back to item);
        index 0 is the most frequent item
    """
    frequencies = Counter(itertools.chain.from_iterable(nested_list))
    index_to_item = [item for item, _ in frequencies.most_common()]
    item_to_index = dict((item, position) for position, item in enumerate(index_to_item))
    return item_to_index, index_to_item
+
def build_iters(data_dir, max_records, train_fraction, batch_size, buckets=None, max_token_chars=5):
    """
    Read the preprocessed sentence/tag data and build bucketed train and validation iterators.

    Reads ``<data_dir>/ner_data.pkl`` (a pickled dataframe with one row per sentence),
    converts tokens, characters and tags to integer indices and builds the vocabularies.
    The tag vocabulary is saved to ``<data_dir>/tag_to_index.pkl``. The first
    ``train_fraction`` of the records is used for training and the rest for validation
    (no shuffling before the split).

    :param data_dir: directory containing ner_data.pkl; tag_to_index.pkl is written here
    :param max_records: keep only the first ``max_records`` records (None keeps all)
    :param train_fraction: fraction of the data to use for training
    :param batch_size: records in mini-batches during training
    :param buckets: size of each bucket in the iterators (None/empty: created automatically)
    :param max_token_chars: characters kept per token (longer tokens are truncated)
    :return: train_iter, val_iter, word_to_index, char_to_index, entity_to_index
    """
    # Read in the preprocessed dataframe
    df = pd.read_pickle(os.path.join(data_dir, "ner_data.pkl"))[:max_records]

    # Get feature lists
    entities = [list(array) for array in df["BILOU_tag"].values]
    sentences = [list(array) for array in df["token"].values]
    chars = [[list(word) for word in sentence] for sentence in sentences]

    # Build vocabularies (characters are counted once per unique word)
    entity_to_index, index_to_entity = build_vocab(entities)
    word_to_index, index_to_word = build_vocab(sentences)
    char_to_index, index_to_char = build_vocab([list(word) for word in index_to_word])
    # Saved next to the input data, i.e. to the directory this function was given.
    # NOTE(review): metrics.py loads this file from the fixed default ../preprocessed_data.
    save_obj(entity_to_index, os.path.join(data_dir, "tag_to_index"))

    # Map strings to integer values
    indexed_entities = [list(map(entity_to_index.get, l)) for l in entities]
    indexed_tokens = [list(map(word_to_index.get, l)) for l in sentences]
    indexed_chars = [[list(map(char_to_index.get, word)) for word in sentence] for sentence in chars]

    # Split into training and testing data
    idx = int(len(indexed_tokens) * train_fraction)
    X_token_train, X_char_train, Y_train = indexed_tokens[:idx], indexed_chars[:idx], indexed_entities[:idx]
    X_token_test, X_char_test, Y_test = indexed_tokens[idx:], indexed_chars[idx:], indexed_entities[idx:]

    # build iterators to feed batches to network; validation reuses the training buckets
    train_iter = iterators.BucketNerIter(sentences=X_token_train, characters=X_char_train, label=Y_train,
                                         max_token_chars=max_token_chars, batch_size=batch_size, buckets=buckets)
    val_iter = iterators.BucketNerIter(sentences=X_token_test, characters=X_char_test, label=Y_test,
                                       max_token_chars=train_iter.max_token_chars, batch_size=batch_size,
                                       buckets=train_iter.buckets)
    return train_iter, val_iter, word_to_index, char_to_index, entity_to_index
+
def sym_gen(seq_len):
    """
    Build NN symbol depending on the length of the input sequence

    Reads the module-level globals ``train_iter``, ``args``, ``char_to_index``,
    ``word_to_index``, ``entity_to_index`` and ``bi_cell`` that are set in ``__main__``.

    :param seq_len: bucket key, i.e. the number of (padded) tokens per sentence in the batch
    :return: (softmax output symbol, data input names, label input names)
    """
    # Character input shape (batch, tokens, max_token_chars); only index 2 is used below.
    char_sentence_shape = train_iter.provide_data[1][1]

    X_sent = mx.symbol.Variable(train_iter.provide_data[0].name)
    X_char_sent = mx.symbol.Variable(train_iter.provide_data[1].name)
    Y = mx.sym.Variable(train_iter.provide_label[0].name)

    ###############################
    # Character embedding component
    ###############################
    char_embeddings = mx.sym.Embedding(data=X_char_sent, input_dim=len(char_to_index), output_dim=args.char_embed, name='char_embed')
    # Add a singleton channel axis for the 3d convolution: (batch, 1, seq_len, chars, char_embed)
    char_embeddings = mx.sym.reshape(data=char_embeddings, shape=(0,1,seq_len,-1,args.char_embed), name='char_embed2')

    char_cnn_outputs = []
    for i, filter_size in enumerate(args.char_filter_list):
        # Kernel covers filter_size characters and the full embedding, one token at a time
        convi = mx.sym.Convolution(data=char_embeddings, kernel=(1, filter_size, args.char_embed), stride=(1, 1, 1),
                                   num_filter=args.char_filters, name="char_conv_layer_" + str(i))
        acti = mx.sym.Activation(data=convi, act_type='tanh')
        # Max-pool over every character position, leaving one value per filter per token
        pooli = mx.sym.Pooling(data=acti, pool_type='max', kernel=(1, char_sentence_shape[2] - filter_size + 1, 1),
                               stride=(1, 1, 1), name="char_pool_layer_" + str(i))
        # Drop the trailing singleton axes and move filters last: (batch, seq_len, char_filters)
        pooli = mx.sym.transpose(mx.sym.Reshape(pooli, shape=(0, 0, 0)), axes=(0, 2, 1), name="cchar_conv_layer_" + str(i))
        char_cnn_outputs.append(pooli)

    # combine features from all filters & apply dropout
    cnn_char_features = mx.sym.Concat(*char_cnn_outputs, dim=2, name="cnn_char_features")
    regularized_cnn_char_features = mx.sym.Dropout(data=cnn_char_features, p=args.dropout, mode='training',
                                                   name='regularized charCnn features')

    ##################################
    # Combine char and word embeddings
    ##################################
    word_embeddings = mx.sym.Embedding(data=X_sent, input_dim=len(word_to_index), output_dim=args.word_embed, name='word_embed')
    # Per-token features along the last axis: word embedding followed by char CNN features
    rnn_features = mx.sym.Concat(*[word_embeddings, regularized_cnn_char_features], dim=2, name='rnn input')

    ##############################
    # Bidirectional LSTM component
    ##############################

    # unroll the lstm cell in time, merging outputs
    bi_cell.reset()
    output, states = bi_cell.unroll(length=seq_len, inputs=rnn_features, merge_outputs=True)

    # Map to num entity classes: one row per token, forward + backward states concatenated
    rnn_output = mx.sym.Reshape(output, shape=(-1, args.lstm_state_size * 2), name='r_output')
    fc = mx.sym.FullyConnected(data=rnn_output, num_hidden=len(entity_to_index), name='fc_layer')

    # reshape back to (batch, num_tags, seq_len): multi_output softmax expects classes on axis 1
    reshaped_fc = mx.sym.transpose(mx.sym.reshape(fc, shape=(-1, seq_len, len(entity_to_index))), axes=(0, 2, 1))
    # Padded label positions (-1) are ignored by the loss
    sm = mx.sym.SoftmaxOutput(data=reshaped_fc, label=Y, ignore_label=-1, use_ignore=True, multi_output=True, name='softmax')
    return sm, [v.name for v in train_iter.provide_data], [v.name for v in train_iter.provide_label]
+
def train(train_iter, val_iter):
    """
    Fit a BucketingModule built from ``sym_gen`` on ``train_iter``, evaluating on ``val_iter``.

    Device list, optimizer, learning rate, epoch count and checkpointing come from the
    module-level ``args``.
    """
    # Imported here rather than at module top: metrics.py loads tag_to_index.pkl at import
    # time, and that file is only written by build_iters.
    import metrics
    # Empty (or None) --gpus means CPU; otherwise one context per comma separated GPU id.
    # (The original used `args.gpus is ''`, an identity test that is not a reliable string comparison.)
    devs = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus else mx.cpu()
    module = mx.mod.BucketingModule(sym_gen, train_iter.default_bucket_key, context=devs)
    module.fit(train_data=train_iter,
               eval_data=val_iter,
               eval_metric=metrics.composite_classifier_metrics(),
               optimizer=args.optimizer,
               optimizer_params={'learning_rate': args.lr},
               initializer=mx.initializer.Uniform(0.1),
               num_epoch=args.num_epochs,
               epoch_end_callback=save_model())
+
if __name__ == '__main__':
    # Command line options; comma separated lists become lists of ints
    args = parser.parse_args()
    if args.buckets:
        args.buckets = [int(size) for size in args.buckets.split(',')]
    else:
        args.buckets = None
    args.char_filter_list = [int(size) for size in args.char_filter_list.split(',')]

    # Data iterators and vocabularies; these module globals are read by sym_gen
    train_iter, val_iter, word_to_index, char_to_index, entity_to_index = build_iters(
        args.data_dir, args.max_records, args.train_fraction, args.batch_size, args.buckets)

    # Stack of bidirectional LSTM layers, each followed by dropout (global bi_cell used by sym_gen)
    bi_cell = mx.rnn.SequentialRNNCell()
    for layer_num in range(args.lstm_layers):
        layer_tag = str(layer_num)
        forward_cell = mx.rnn.LSTMCell(num_hidden=args.lstm_state_size, prefix="forward_layer_" + layer_tag)
        backward_cell = mx.rnn.LSTMCell(num_hidden=args.lstm_state_size, prefix="backward_layer_" + layer_tag)
        bi_cell.add(mx.rnn.BidirectionalCell(forward_cell, backward_cell))
        bi_cell.add(mx.rnn.DropoutCell(args.dropout))

    train(train_iter, val_iter)
\ No newline at end of file
diff --git a/example/named_entity_recognition/src/preprocess.py b/example/named_entity_recognition/src/preprocess.py
new file mode 100644
index 0000000..6ae348a
--- /dev/null
+++ b/example/named_entity_recognition/src/preprocess.py
@@ -0,0 +1,50 @@
+# !/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+import pandas as pd
+import numpy as np
+
import os

# Output directory; must match ner.py's default --data-dir and the README instructions.
PREPROCESSED_DIR = "../preprocessed_data"

#read in csv of NER training data
df = pd.read_csv("../data/ner_dataset.csv", encoding="ISO-8859-1")

#rename columns
df = df.rename(columns={"Sentence #": "utterance_id",
                        "Word": "token",
                        "POS": "POS_tag",
                        "Tag": "BILOU_tag"})

#clean utterance_id column
df["utterance_id"] = df["utterance_id"].str.replace('Sentence: ', '')

#fill np.nan utterance ID's with the last valid entry (ffill replaces the deprecated fillna(method='ffill'))
df = df.ffill()
df["utterance_id"] = df["utterance_id"].apply(int)

#melt BILOU tags and tokens into an array per utterance
df1 = df.groupby("utterance_id")["BILOU_tag"].apply(lambda x: np.array(x)).to_frame().reset_index()
df2 = df.groupby("utterance_id")["token"].apply(lambda x: np.array(x)).to_frame().reset_index()
df3 = df.groupby("utterance_id")["POS_tag"].apply(lambda x: np.array(x)).to_frame().reset_index()

#join the results on utterance id
df = df1.merge(df2.merge(df3, how="left", on="utterance_id"), how="left", on="utterance_id")

#save the dataframe as a pickle where ner.py reads it (ner_data.pkl in its --data-dir)
if not os.path.isdir(PREPROCESSED_DIR):
    os.makedirs(PREPROCESSED_DIR)
df.to_pickle(os.path.join(PREPROCESSED_DIR, "ner_data.pkl"))
\ No newline at end of file