Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/12 18:52:35 UTC

[GitHub] piiswrong closed pull request #9345: example/ctc improvements

piiswrong closed pull request #9345: example/ctc improvements
URL: https://github.com/apache/incubator-mxnet/pull/9345

This is a pull request merged from a forked repository. Because GitHub hides
the original diff once such a pull request is merged, it is reproduced below
for the sake of provenance:

diff --git a/docs/tutorials/speech_recognition/baidu_warp_ctc.md b/docs/tutorials/speech_recognition/baidu_warp_ctc.md
deleted file mode 100644
index 6277a19bfd..0000000000
--- a/docs/tutorials/speech_recognition/baidu_warp_ctc.md
+++ /dev/null
@@ -1,97 +0,0 @@
-# Using Baidu Warp-CTC with MXNet
-
-
-Baidu-WarpCTC is a CTC implementation by Baidu that supports using GPU processors. It supports using CTC with LSTM to solve label alignment problems in many areas, such as OCR and speech recognition.
-
-You can get the source code for the example on [GitHub](https://github.com/dmlc/mxnet/tree/master/example/warpctc).
-
-## Install Baidu Warp-CTC
-
-```
-  cd ~/
-  git clone https://github.com/baidu-research/warp-ctc
-  cd warp-ctc
-  mkdir build
-  cd build
-  cmake ..
-  make
-  sudo make install
-```
-
-## Enable Warp-CTC in MXNet
-
-```
-  comment out following lines in make/config.mk
-  WARPCTC_PATH = $(HOME)/warp-ctc
-  MXNET_PLUGINS += plugin/warpctc/warpctc.mk
-
-  rebuild mxnet by
-  make clean && make -j4
-```
-
-## Run Examples
-
-There are two examples. One is a toy example that validates CTC integration. The second is an OCR example with LSTM and CTC. You can run it by typing the following code:
-
-```
-  cd examples/warpctc
-  python lstm_ocr.py
-```
-
-The OCR example is constructed as follows:
-
-1. It generates a 80x30-pixel image for a 4-digit captcha using a Python captcha library.
-2. The 80x30 image is used as 80 input for LSTM, and every input is one column of the image (a 30 dim vector).
-3. The output layer use CTC loss.
-
-The following code shows the detailed construction of the net: 
-
-```
-  def lstm_unroll(num_lstm_layer, seq_len,
-                  num_hidden, num_label):
-    param_cells = []
-    last_states = []
-    for i in range(num_lstm_layer):
-        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
-                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
-                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
-                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
-        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
-                          h=mx.sym.Variable("l%d_init_h" % i))
-        last_states.append(state)
-    assert(len(last_states) == num_lstm_layer)
-    data = mx.sym.Variable('data')
-    label = mx.sym.Variable('label')
-
-    #every column of image is an input, there are seq_len inputs
-    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
-    hidden_all = []
-    for seqidx in range(seq_len):
-        hidden = wordvec[seqidx]
-        for i in range(num_lstm_layer):
-            next_state = lstm(num_hidden, indata=hidden,
-                              prev_state=last_states[i],
-                              param=param_cells[i],
-                              seqidx=seqidx, layeridx=i)
-            hidden = next_state.h
-            last_states[i] = next_state
-        hidden_all.append(hidden)
-    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
-    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
-
-    # here we do NOT need to transpose label as other lstm examples do
-    label = mx.sym.Reshape(data=label, target_shape=(0,))
-    #label should be int type, so use cast
-    label = mx.sym.Cast(data = label, dtype = 'int32')
-    sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len)
-    return sm
-```
-
-## Supporting Multi-label Length
-
-Provide labels with length b. For samples whose label length is smaller than b, append 0 to the label data to make it have length b.
-
-0 is reserved for a blank label.
-
-## Next Steps
-* [MXNet tutorials index](http://mxnet.io/tutorials/index.html)
diff --git a/docs/tutorials/speech_recognition/ctc.md b/docs/tutorials/speech_recognition/ctc.md
new file mode 100644
index 0000000000..9c9a9c98db
--- /dev/null
+++ b/docs/tutorials/speech_recognition/ctc.md
@@ -0,0 +1,15 @@
+# Connectionist Temporal Classification
+
+[Connectionist Temporal Classification](https://www.cs.toronto.edu/~graves/icml_2006.pdf) (CTC) is a cost function that is used to train Recurrent Neural Networks (RNNs) to label unsegmented input sequence data in supervised learning. For example, in a speech recognition application using a typical cross-entropy loss, the input signal needs to be segmented into words or sub-words. However, using CTC-loss, it suffices to provide one label sequence per input sequence, and the network learns both the alignment and the labeling. Baidu's warp-ctc page contains a more detailed [introduction to CTC-loss](https://github.com/baidu-research/warp-ctc#introduction).
+
+## CTC-loss in MXNet
+MXNet supports two CTC-loss layers in Symbol API:
+
+* `mxnet.symbol.contrib.ctc_loss` is implemented in MXNet and included as part of the standard package.
+* `mxnet.symbol.WarpCTC` uses Baidu's warp-ctc library and requires building both the warp-ctc library and the MXNet library from source.
+
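+The following minimal sketch shows how the built-in loss can be attached to a prediction symbol. It mirrors the pattern used in the example's `lstm.py`; the variable names are illustrative only:
+
+```python
+import mxnet as mx
+
+# Per-step class scores produced by the network, and the unaligned label sequence
+pred = mx.sym.Variable('pred')
+label = mx.sym.Variable('label')
+
+# ctc_loss expects predictions of shape (sequence_length, batch_size, alphabet_size)
+loss = mx.sym.contrib.ctc_loss(data=pred, label=label)
+ctc_loss = mx.sym.MakeLoss(loss)
+```
+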
+## LSTM OCR Example
+MXNet's example folder contains a [CTC example](https://github.com/apache/incubator-mxnet/tree/master/example/ctc) for using CTC loss with an LSTM network to perform Optical Character Recognition (OCR) prediction on CAPTCHA images. The example demonstrates use of both CTC loss options, as well as inference after training using network symbol and parameter checkpoints.
+
+## Next Steps
+* [MXNet tutorials index](http://mxnet.io/tutorials/index.html)
diff --git a/docs/tutorials/speech_recognition/speech_lstm.md b/docs/tutorials/speech_recognition/speech_lstm.md
deleted file mode 100644
index 17e2ca0002..0000000000
--- a/docs/tutorials/speech_recognition/speech_lstm.md
+++ /dev/null
@@ -1,156 +0,0 @@
-# Speech LSTM
-You can get the source code for these examples on [GitHub](https://github.com/dmlc/mxnet/tree/master/example/speech-demo).
-
-## Speech Acoustic Modeling Example
-
-The examples folder contains examples for speech recognition:
-
-- [lstm_proj.py](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/lstm_proj.py): Functions for building an LSTM network with and without a projection layer.
-- [io_util.py](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/io_util.py): Wrapper functions for `DataIter` over speech data.
-- [train_lstm_proj.py](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/train_lstm_proj.py): A script for training an LSTM acoustic model.
-- [decode_mxnet.py](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/decode_mxnet.py): A script for decoding an LSTMP acoustic model.
-- [default.cfg](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/default.cfg): Configuration for training on the `AMI` SDM1 dataset. You can use it as a template for writing other configuration files.
-- [python_wrap](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/python_wrap): C wrappers for Kaldi C++ code, built into an .so file. Python code that loads the .so file and calls the C wrapper functions in `io_func/feat_readers/reader_kaldi.py`.
-
-Connect to Kaldi:
-
-- [decode_mxnet.sh](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/decode_mxnet.sh): Called by Kaldi to decode an acoustic model trained by MXNet (select the `simple` method for decoding).
-
-A full receipt:
-
-- [run_ami.sh](https://github.com/dmlc/mxnet/tree/master/example/speech-demo/run_ami.sh): A full receipt to train and decode an acoustic model on AMI. It takes features and alignment from Kaldi to train an acoustic model and decode it.
-
-To create the speech acoustic modeling example, use the following steps.
-
-### Build Kaldi
-
-Build Kaldi as shared libraries if you have not already done so.
-
-```bash
-cd kaldi/src
-./configure --shared # and other options that you need
-make depend
-make
-```
-
-### Build the Python Wrapper
-
-1. Copy or link the attached `python_wrap` folder to `kaldi/src`.
-2. Compile python_wrap/.
-
-```
-cd kaldi/src/python_wrap/
-make
-```
-
-### Extract Features and Prepare Frame-level Labels
-
-The acoustic models use Mel filter-bank or MFCC as input features. They also need to use Kaldi to perform force-alignment to generate frame-level labels from the text transcriptions. For example, if you want to work on the `AMI` data `SDM1`, you can run `kaldi/egs/ami/s5/run_sdm.sh`. Before you can run the examples, you need to configure some paths in `kaldi/egs/ami/s5/cmd.sh` and `kaldi/egs/ami/s5/run_sdm.sh`. Refer to Kaldi's documentation for details.
-
-The default `run_sdm.sh` script generates the force-alignment labels in their stage 7, and saves the force-aligned labels in `exp/sdm1/tri3a_ali`. The default script generates MFCC features (13-dimensional). You can try training with the MFCC features, or you can create Mel filter-bank features by yourself. For example, you can use a script like this to compute Mel filter-bank features using Kaldi:
-
-```bash
-#!/bin/bash -u
-
-. ./cmd.sh
-. ./path.sh
-
-# SDM - Single Distant Microphone
-micid=1 #which mic from array should be used?
-mic=sdm$micid
-
-# Set bash to 'debug' mode, it prints the commands (option '-x') and exits on :
-# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline',
-set -euxo pipefail
-
-# Path where AMI gets downloaded (or where locally available):
-AMI_DIR=$PWD/wav_db # Default,
-data_dir=$PWD/data/$mic
-
-# make filter bank data
-for dset in train dev eval; do
-  steps/make_fbank.sh --nj 48 --cmd "$train_cmd" $data_dir/$dset \
-    $data_dir/$dset/log $data_dir/$dset/data-fbank
-  steps/compute_cmvn_stats.sh $data_dir/$dset \
-    $data_dir/$dset/log $data_dir/$dset/data
-
-  apply-cmvn --utt2spk=ark:$data_dir/$dset/utt2spk \
-    scp:$data_dir/$dset/cmvn.scp scp:$data_dir/$dset/feats.scp \
-    ark,scp:$data_dir/$dset/feats-cmvn.ark,$data_dir/$dset/feats-cmvn.scp
-
-  mv $data_dir/$dset/feats-cmvn.scp $data_dir/$dset/feats.scp
-done
-```
-`apply-cmvn` provides mean-variance normalization. The default setup was applied per speaker. It's more common to perform mean-variance normalization for the whole corpus, and then feed the results to the neural networks:
-
-```
- compute-cmvn-stats scp:data/sdm1/train_fbank/feats.scp data/sdm1/train_fbank/cmvn_g.ark
- apply-cmvn --norm-vars=true data/sdm1/train_fbank/cmvn_g.ark scp:data/sdm1/train_fbank/feats.scp ark,scp:data/sdm1/train_fbank_gcmvn/feats.ark,data/sdm1/train_fbank_gcmvn/feats.scp
-```
-Note that Kaldi always tries to find features in `feats.scp`. Ensure that the normalized features are organized as Kaldi expects them during decoding.
-
-Finally, put the features and labels together in a file so that MXNet can find them. More specifically, for each data set (train, dev, eval), you will need to create a file similar to `train_mxnet.feats`, with the following contents:
-
-```
-TRANSFORM scp:feat.scp
-scp:label.scp
-```
-
-`TRANSFORM` is the transformation you want to apply to the features. By default, we use `NO_FEATURE_TRANSFORM`. The `scp:` syntax is from Kaldi. `feat.scp` is typically the file from `data/sdm1/train/feats.scp`, and `label.scp` is converted from the force-aligned labels located in `exp/sdm1/tri3a_ali`. Because the force-alignments are generated only on the training data, we split the training set in two, using a 90/10 ratio, and then use the 1/10 holdout as the dev set (validation set). The script [run_ami.sh](https://github.com/dmlc/mxnet/blob/master/example/speech-demo/run_ami.sh) automatically splits and formats the file for MXNet. Before running it, set the path in the script correctly. The [run_ami.sh](https://github.com/dmlc/mxnet/blob/master/example/speech-demo/run_ami.sh) script actually runs the full pipeline, including training the acoustic model and decoding. If the scripts ran successfully, you can skip the following sections.
-
-### Run MXNet Acoustic Model Training
-
-1. Return to the speech demo directory in MXNet. Make a copy of `default.cfg`, and edit the necessary parameters, such as the path to the dataset you just prepared.
-2. Run `python train_lstm.py --configfile=your-config.cfg`. For help, use `python train_lstm.py --help`. You can set all of the configuration parameters in `default.cfg`, the customized config file, and through the command line (e.g., using `--train_batch_size=50`). The latter values overwrite the former ones.
-
-Here are some example outputs from training on the TIMIT dataset:
-
-```
-Example output for TIMIT:
-Summary of dataset ==================
-bucket of len 100 : 3 samples
-bucket of len 200 : 346 samples
-bucket of len 300 : 1496 samples
-bucket of len 400 : 974 samples
-bucket of len 500 : 420 samples
-bucket of len 600 : 90 samples
-bucket of len 700 : 11 samples
-bucket of len 800 : 2 samples
-Summary of dataset ==================
-bucket of len 100 : 0 samples
-bucket of len 200 : 28 samples
-bucket of len 300 : 169 samples
-bucket of len 400 : 107 samples
-bucket of len 500 : 41 samples
-bucket of len 600 : 6 samples
-bucket of len 700 : 3 samples
-bucket of len 800 : 0 samples
-2016-04-21 20:02:40,904 Epoch[0] Train-Acc_exlude_padding=0.154763
-2016-04-21 20:02:40,904 Epoch[0] Time cost=91.574
-2016-04-21 20:02:44,419 Epoch[0] Validation-Acc_exlude_padding=0.353552
-2016-04-21 20:04:17,290 Epoch[1] Train-Acc_exlude_padding=0.447318
-2016-04-21 20:04:17,290 Epoch[1] Time cost=92.870
-2016-04-21 20:04:20,738 Epoch[1] Validation-Acc_exlude_padding=0.506458
-2016-04-21 20:05:53,127 Epoch[2] Train-Acc_exlude_padding=0.557543
-2016-04-21 20:05:53,128 Epoch[2] Time cost=92.390
-2016-04-21 20:05:56,568 Epoch[2] Validation-Acc_exlude_padding=0.548100
-```
-
-The final frame accuracy was approximately 62%.
-
-### Run Decode on the Trained Acoustic Model
-
-1. Estimate senone priors by running `python make_stats.py --configfile=your-config.cfg | copy-feats ark:- ark:label_mean.ark` (edit necessary items, such as the path to the training dataset). This command generates the label counts in `label_mean.ark`.
-2. Link to the necessary Kaldi decode setup, e.g., `local/` and `utils/` and run `./run_ami.sh --model prefix model --num_epoch num`.
-
-Here are the results for the TIMIT and AMI test sets (using the default setup, three-layer LSTM with projection layers):
-
-	| Corpus | WER |
-	|--------|-----|
-	|TIMIT   | 18.9|
-	|AMI     | 51.7 (42.2) |
-
-For AMI 42.2 was evaluated non-overlapped speech. The Kaldi-HMM baseline was 67.2%, and DNN was 57.5%.
-
-## Next Steps
-* [MXNet tutorials index](http://mxnet.io/tutorials/index.html)
diff --git a/example/ctc/README.md b/example/ctc/README.md
index 9035582a53..a2f54cffaf 100644
--- a/example/ctc/README.md
+++ b/example/ctc/README.md
@@ -1,80 +1,113 @@
-# CTC with Mxnet
+# Connectionist Temporal Classification
 
-## Overview
-This example is a modification of [warpctc](https://github.com/dmlc/mxnet/tree/master/example/warpctc)
-It demonstrates the usage of  ```mx.contrib.sym.ctc_loss``` 
+[Connectionist Temporal Classification](https://www.cs.toronto.edu/~graves/icml_2006.pdf) (CTC) is a cost function that is used to train Recurrent Neural Networks (RNNs) to label unsegmented input sequence data in supervised learning. For example, in a speech recognition application using a typical cross-entropy loss, the input signal needs to be segmented into words or sub-words. However, using CTC-loss, a single unaligned label sequence per input sequence is sufficient for the network to learn both the alignment and the labeling. Baidu's warp-ctc page contains a more detailed [introduction to CTC-loss](https://github.com/baidu-research/warp-ctc#introduction).
 
-## Core code change
+## LSTM OCR Example
+In this example, we use CTC loss to train a network for Optical Character Recognition (OCR) of CAPTCHA images. The example uses the `captcha` Python package to generate a random dataset for training. Training the network requires a CTC-loss layer, and MXNet provides two options for such a layer. The OCR example is constructed as follows:
 
-The following implementation of ```lstm_unroll```  function is introduced in ```lstm.py``` demonstrates the usage of
-```mx.contrib.sym.ctc_loss```.
+1. 80x30 CAPTCHA images containing 3 to 4 random digits are generated using the Python `captcha` library.
+2. Each image is used as a data sequence with sequence-length of 80 and vector length of 30.
+3. The output layer uses CTC loss in training and softmax in inference.
 
-```Cython
-def lstm_unroll(num_lstm_layer, seq_len,
-                num_hidden, num_label):
-    param_cells = []
-    last_states = []
-    for i in range(num_lstm_layer):
-        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
-                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
-                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
-                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
-        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
-                          h=mx.sym.Variable("l%d_init_h" % i))
-        last_states.append(state)
-    assert (len(last_states) == num_lstm_layer)
+Note: When using CTC loss, one prediction class is reserved for the blank label. In this example, which predicts the digits 0 to 9, the softmax output therefore has 11 classes: class 0 is the blank label, and classes 1 to 10 correspond to digits 0 to 9 respectively.
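+
+A minimal sketch of this label encoding (the helper name below is illustrative; it mirrors the mapping described above, padding short labels with the blank class 0):
+
+```python
+import numpy as np
+
+def encode_label(digits, num_label=4):
+    """Map a digit string such as '472' to CTC training labels, padded with the blank label 0."""
+    label = np.zeros(num_label)
+    for i, ch in enumerate(digits):
+        label[i] = 1 + int(ch)  # digit d becomes class d + 1; class 0 stays reserved for blank
+    return label
+```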
 
-    # embeding layer
-    data = mx.sym.Variable('data')
-    label = mx.sym.Variable('label')
-    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
+### Description of the files
+The LSTM OCR example contains the following files:
+* `captcha_generator.py`: Module for generating random 3 or 4 digit CAPTCHA images for training. It also contains a script for saving sample CAPTCHA images to an output file for inference testing.
+* `ctc_metrics.py`: Module for calculating prediction accuracy during training. Two accuracy measures are implemented: a simple measure that divides the number of fully correct predictions by the total number of predictions, and a second measure based on the average Longest Common Subsequence (LCS) ratio of all predictions (see the sketch after this list).
+* `hyperparams.py`: Contains all hyperparameters for the network structure and training.
+* `lstm.py`: Contains the LSTM network implementation, with options for adding mxnet-ctc or warp-ctc loss for training, as well as softmax for inference.
+* `lstm_ocr_infer.py`: Script for running inference after training.
+* `lstm_ocr_train.py`: Script for training with ctc or warp-ctc loss.
+* `multiproc_data.py`: A module for multiprocess data generation.
+* `ocr_iter.py`: A `DataIter` module for iterating through the training data.
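+
+For instance, `CtcMetrics.ctc_label` collapses a raw per-step prediction into a label sequence by dropping blanks (class 0) and consecutive repeats; a small illustrative call:
+
+```python
+from ctc_metrics import CtcMetrics
+
+# Blanks (0) and consecutive repeated values are removed
+print(CtcMetrics.ctc_label([0, 3, 3, 0, 5, 5, 5, 0, 2]))  # -> [3, 5, 2]
+```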
 
-    hidden_all = []
-    for seqidx in range(seq_len):
-        hidden = wordvec[seqidx]
-        for i in range(num_lstm_layer):
-            next_state = lstm(num_hidden, indata=hidden,
-                              prev_state=last_states[i],
-                              param=param_cells[i],
-                              seqidx=seqidx, layeridx=i)
-            hidden = next_state.h
-            last_states[i] = next_state
-        hidden_all.append(hidden)
+## CTC-loss in MXNet
+MXNet supports two CTC-loss layers in Symbol API:
 
-    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
+* `mxnet.symbol.contrib.ctc_loss` is implemented in MXNet and included as part of the standard package.
+* `mxnet.symbol.WarpCTC` uses Baidu's warp-ctc library and requires building both the warp-ctc library and the MXNet library from source.
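+
+For illustration (mirroring `lstm.py` in this example; the variable values below are assumptions), `WarpCTC` expects a flattened `int32` label and explicit length arguments, and is available only when MXNet is built with the warp-ctc plugin as described below:
+
+```python
+import mxnet as mx
+
+seq_len, num_label = 80, 4
+pred_fc = mx.sym.Variable('pred_fc')   # per-step class scores from the network
+label = mx.sym.Variable('label')
+
+# Flatten the label and cast it to int32 before attaching the warp-ctc loss
+label_flat = mx.sym.Cast(mx.sym.Reshape(data=label, shape=(-1,)), dtype='int32')
+warp_loss = mx.sym.WarpCTC(data=pred_fc, label=label_flat,
+                           label_length=num_label, input_length=seq_len)
+```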
 
-    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
-    pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0))
+### Building MXNet with warp-ctc
+In order to use `mxnet.symbol.WarpCTC` layer, you need to first build Baidu's [warp-ctc](https://github.com/baidu-research/warp-ctc) library from source and then build MXNet from source with warp-ctc config flags enabled.
 
-    loss = mx.contrib.sym.ctc_loss(data=pred_ctc, label=label)
-    ctc_loss = mx.sym.MakeLoss(loss)
+#### Building warp-ctc
+Follow the [instructions here](https://github.com/baidu-research/warp-ctc#compilation) to build warp-ctc from source. Once compiled, install the library by running the following command from the `warp-ctc/build` directory:
+```
+$ sudo make install
+```
 
-    softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc)
-    softmax_loss = mx.sym.MakeLoss(softmax_class)
-    softmax_loss = mx.sym.BlockGrad(softmax_loss)
+#### Building MXNet from source with warp-ctc integration
+To build MXNet from source, follow the [instructions here](http://mxnet.incubator.apache.org/install/index.html). After choosing your system configuration, Python environment, and the "Build from Source" option, and before running `make` in step 4, enable warp-ctc integration by uncommenting the following lines in `make/config.mk` in the `incubator-mxnet` directory:
+```
+WARPCTC_PATH = $(HOME)/warp-ctc
+MXNET_PLUGINS += plugin/warpctc/warpctc.mk
+```
 
-    return mx.sym.Group([softmax_loss, ctc_loss])
+## Run LSTM OCR Example
+Running this example requires the following prerequisites:
+* The `captcha` and `opencv` Python packages are installed:
+```
+$ pip install captcha
+$ pip install opencv-python
+```
+* You have access to one (or more) `ttf` font files. You can download a collection of font files from [Ubuntu's website](https://design.ubuntu.com/font/). The instructions in this section assume that a `./font/Ubuntu-M.ttf` file exists under the `example/ctc/` directory.
+
+### Training
+The training script demonstrates how to construct a network with both CTC loss options and train it using the `mxnet.Module` API. Training is done by generating random CAPTCHA images using the font(s) provided. This example uses 80x30 CAPTCHA images that contain 3 to 4 digits each.
+
+When using a GPU for training, data generation becomes the training bottleneck. To remedy this, the example implements multiprocess data generation. The number of image-generating processes, as well as whether to train on CPU or GPU, can be configured through command-line arguments.
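+
+A rough usage sketch of the multiprocess generator defined in `captcha_generator.py` (the sizes and counts below are illustrative, not the example's exact settings):
+
+```python
+from captcha_generator import MPDigitCaptcha
+
+mp_captcha = MPDigitCaptcha(
+    font_paths=['font/Ubuntu-M.ttf'], h=30, w=80,
+    num_digit_min=3, num_digit_max=4,
+    num_processes=4, max_queue_size=1024)
+mp_captcha.start()         # spawn the image-generating processes
+sample = mp_captcha.get()  # fetch one generated sample from the queue
+mp_captcha.reset()         # stop the processes when done
+```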
+
+To see the list of all arguments:
+```
+$ python lstm_ocr_train.py --help
+```
+From the command line, you can also select between the ctc and warp-ctc loss options. For example, the following command starts a training session on a single GPU with 4 CAPTCHA-generating processes, using ctc loss and the `font/Ubuntu-M.ttf` font file:
 ```
+$ python lstm_ocr_train.py --gpu 1 --num_proc 4 --loss ctc font/Ubuntu-M.ttf
+```
+
+You can train with multiple fonts by specifying a folder that contains multiple `ttf` font files instead. Training saves a checkpoint after each epoch. The checkpoint prefix is 'ocr' by default, but it can be changed with the `--prefix` argument.
 
-## Prerequisites
+When testing this example, the following system configuration was used:
+* p2.xlarge AWS EC2 instance (4 x CPU and 1 x K80 GPU)
+* Deep Learning Amazon Machine Image (with mxnet 1.0.0)
 
-Please ensure that following prerequisites are satisfied before running this examples.
+This training example finishes after 100 epochs with ~87% accuracy. If you continue training further, the network achieves over 95% accuracy. Similar accuracy is achieved with both ctc (`--loss ctc`) and warp-ctc (`--loss warpctc`) options. Logs of the last training epoch:
 
-- ```captcha``` python package is installed.
-- ```cv2``` (or ```openCV```) python package is installed.
-- The test requires font file (```ttf``` format). The user either would need to create ```.\data\```  directory and place the font file in that directory. The user can also edit following line to specify path to the font file.
-```cython
-        # you can get this font from http://font.ubuntu.com/
-        self.captcha = ImageCaptcha(fonts=['./data/Xerox.ttf'])
+```
+05:58:36,128 Epoch[99] Batch [50]	Speed: 1067.63 samples/sec	accuracy=0.877757
+05:58:42,119 Epoch[99] Batch [100]	Speed: 1068.14 samples/sec	accuracy=0.859688
+05:58:48,114 Epoch[99] Batch [150]	Speed: 1067.73 samples/sec	accuracy=0.870469
+05:58:54,107 Epoch[99] Batch [200]	Speed: 1067.91 samples/sec	accuracy=0.864219
+05:58:58,004 Epoch[99] Train-accuracy=0.877367
+05:58:58,005 Epoch[99] Time cost=28.068
+05:58:58,047 Saved checkpoint to "ocr-0100.params"
+05:59:00,721 Epoch[99] Validation-accuracy=0.868886
 ```
 
-## How to run
+### Inference
+The inference script demonstrates how to load a network from a checkpoint, modify its final layer, and predict the label of a CAPTCHA image using the `mxnet.Module` API. You can choose the prefix as well as the epoch number of the checkpoint through command-line arguments. To see the full list of arguments:
+```
+$ python lstm_ocr_infer.py --help
+```
+For example, to predict the label for the 'sample.jpg' file using the 'ocr' prefix and the checkpoint at epoch 100:
+```
+$ python lstm_ocr_infer.py --prefix ocr --epoch 100 sample.jpg
 
-The users would need to run the script ```lstm_ocr.py``` in order to exercise the above code change.
-```cython
-python lstm_ocr.py
-``` 
+Digits: [0, 0, 8, 9]
+```
 
-## Further reading
+Note: The above command expects the following files, generated by the training script, to exist in the current directory:
+* ocr-symbol.json
+* ocr-0100.params
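+
+Under the hood, the inference script rebuilds an inference symbol from the training checkpoint roughly as follows (a condensed sketch of `load_module` in `lstm_ocr_infer.py`):
+
+```python
+import mxnet as mx
+
+sym, arg_params, aux_params = mx.model.load_checkpoint('ocr', 100)
+# Drop the CTC loss head and attach a plain softmax to the 'pred_fc' layer
+pred_fc = sym.get_internals()['pred_fc_output']
+sym = mx.sym.softmax(data=pred_fc)
+```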
 
-In order to run the ```ocr_predict.py```  please refer to [ReadMe](https://github.com/apache/incubator-mxnet/blob/master/example/warpctc/README.md) file in [warpctc](https://github.com/dmlc/mxnet/tree/master/example/warpctc)
+#### Generate CAPTCHA samples
+CAPTCHA images can be generated using the `captcha_generator.py` script. To see the list of all arguments:
+```
+$ python captcha_generator.py --help
+```
+For example, to generate a CAPTCHA image with random digits using the 'font/Ubuntu-M.ttf' font and save it to the 'sample.jpg' file:
+```
+$ python captcha_generator.py font/Ubuntu-M.ttf sample.jpg
+```
diff --git a/example/ctc/captcha_generator.py b/example/ctc/captcha_generator.py
new file mode 100644
index 0000000000..97fab4082e
--- /dev/null
+++ b/example/ctc/captcha_generator.py
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Helper classes for multiprocess captcha image generation
+
+This module also provides a CLI script for saving captcha images to a file.
+"""
+
+from __future__ import print_function
+import random
+
+from captcha.image import ImageCaptcha
+import cv2
+from multiproc_data import MPData
+import numpy as np
+
+
+class CaptchaGen(object):
+    """
+    Generates a captcha image
+    """
+    def __init__(self, h, w, font_paths):
+        """
+        Parameters
+        ----------
+        h: int
+            Height of the generated images
+        w: int
+            Width of the generated images
+        font_paths: list of str
+            List of all fonts in ttf format
+        """
+        self.captcha = ImageCaptcha(fonts=font_paths)
+        self.h = h
+        self.w = w
+
+    def image(self, captcha_str):
+        """
+        Generate a greyscale captcha image representing number string
+
+        Parameters
+        ----------
+        captcha_str: str
+            string of characters for the captcha image
+
+        Returns
+        -------
+        numpy.ndarray
+            Generated greyscale image in np.ndarray float type with values normalized to [0, 1]
+        """
+        img = self.captcha.generate(captcha_str)
+        img = np.fromstring(img.getvalue(), dtype='uint8')
+        img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
+        img = cv2.resize(img, (self.h, self.w))
+        img = img.transpose(1, 0)
+        img = np.multiply(img, 1 / 255.0)
+        return img
+
+
+class DigitCaptcha(object):
+    """
+    Provides shape() and get() interface for digit-captcha image generation
+    """
+    def __init__(self, font_paths, h, w, num_digit_min, num_digit_max):
+        """
+        Parameters
+        ----------
+        font_paths: list of str
+            List of path to ttf font files
+        h: int
+            height of the generated image
+        w: int
+            width of the generated image
+        num_digit_min: int
+            minimum number of digits in generated captcha image
+        num_digit_max: int
+            maximum number of digits in generated captcha image
+        """
+        self.num_digit_min = num_digit_min
+        self.num_digit_max = num_digit_max
+        self.captcha = CaptchaGen(h=h, w=w, font_paths=font_paths)
+
+    @property
+    def shape(self):
+        """
+        Returns shape of the image data generated
+
+        Returns
+        -------
+        tuple(int, int)
+        """
+        return self.captcha.h, self.captcha.w
+
+    def get(self):
+        """
+        Generate and return a random captcha image sample
+
+        Returns
+        -------
+        (numpy.ndarray, str)
+            A captcha image normalized to [0, 1], and the string of digits it contains
+        """
+        return self._gen_sample()
+
+    @staticmethod
+    def get_rand(num_digit_min, num_digit_max):
+        """
+        Generates a character string of digits. The number of digits is
+        between num_digit_min and num_digit_max
+        Returns
+        -------
+        str
+        """
+        buf = ""
+        max_len = random.randint(num_digit_min, num_digit_max)
+        for i in range(max_len):
+            buf += str(random.randint(0, 9))
+        return buf
+
+    def _gen_sample(self):
+        """
+        Generate a random captcha image sample
+        Returns
+        -------
+        (numpy.ndarray, str)
+            Tuple of image (numpy ndarray) and character string of digits used to generate the image
+        """
+        num_str = self.get_rand(self.num_digit_min, self.num_digit_max)
+        return self.captcha.image(num_str), num_str
+
+
+class MPDigitCaptcha(DigitCaptcha):
+    """
+    Handles multi-process captcha image generation
+    """
+    def __init__(self, font_paths, h, w, num_digit_min, num_digit_max, num_processes, max_queue_size):
+        """
+
+        Parameters
+        ----------
+        font_paths: list of str
+            List of path to ttf font files
+        h: int
+            height of the generated image
+        w: int
+            width of the generated image
+        num_digit_min: int
+            minimum number of digits in generated captcha image
+        num_digit_max: int
+            maximum number of digits in generated captcha image
+        num_processes: int
+            Number of processes to spawn
+        max_queue_size: int
+            Maximum images in queue before processes wait
+        """
+        super(MPDigitCaptcha, self).__init__(font_paths, h, w, num_digit_min, num_digit_max)
+        self.mp_data = MPData(num_processes, max_queue_size, self._gen_sample)
+
+    def start(self):
+        """
+        Starts the processes
+        """
+        self.mp_data.start()
+
+    def get(self):
+        """
+        Get an image from the queue
+
+        Returns
+        -------
+        np.ndarray
+            A captcha image, normalized to [0, 1]
+        """
+        return self.mp_data.get()
+
+    def reset(self):
+        """
+        Resets the generator by stopping all processes
+        """
+        self.mp_data.reset()
+
+
+if __name__ == '__main__':
+    import argparse
+
+    def main():
+        parser = argparse.ArgumentParser()
+        parser.add_argument("font_path", help="Path to ttf font file")
+        parser.add_argument("output", help="Output filename including extension (e.g. 'sample.jpg')")
+        parser.add_argument("--num", help="Up to 4 digit number [Default: random]")
+        args = parser.parse_args()
+
+        captcha = ImageCaptcha(fonts=[args.font_path])
+        captcha_str = args.num if args.num else DigitCaptcha.get_rand(3, 4)
+        img = captcha.generate(captcha_str)
+        img = np.fromstring(img.getvalue(), dtype='uint8')
+        img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
+        cv2.imwrite(args.output, img)
+        print("Captcha image with digits {} written to {}".format([int(c) for c in captcha_str], args.output))
+
+    main()
diff --git a/example/ctc/ctc_metrics.py b/example/ctc/ctc_metrics.py
new file mode 100644
index 0000000000..0db680af18
--- /dev/null
+++ b/example/ctc/ctc_metrics.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Contains a class for calculating CTC eval metrics"""
+
+from __future__ import print_function
+
+import numpy as np
+
+
+class CtcMetrics(object):
+    def __init__(self, seq_len):
+        self.seq_len = seq_len
+
+    @staticmethod
+    def ctc_label(p):
+        """
+        Iterates through p, identifying non-zero and non-repeating values, and returns them in a list
+        Parameters
+        ----------
+        p: list of int
+
+        Returns
+        -------
+        list of int
+        """
+        ret = []
+        p1 = [0] + p
+        for i, _ in enumerate(p):
+            c1 = p1[i]
+            c2 = p1[i+1]
+            if c2 == 0 or c2 == c1:
+                continue
+            ret.append(c2)
+        return ret
+
+    @staticmethod
+    def _remove_blank(l):
+        """ Removes trailing zeros in the list of integers and returns a new list of integers"""
+        ret = []
+        for i, _ in enumerate(l):
+            if l[i] == 0:
+                break
+            ret.append(l[i])
+        return ret
+
+    @staticmethod
+    def _lcs(p, l):
+        """ Calculates the Longest Common Subsequence between p and l (both list of int) and returns its length"""
+        # Dynamic Programming Finding LCS
+        if len(p) == 0:
+            return 0
+        P = np.array(list(p)).reshape((1, len(p)))
+        L = np.array(list(l)).reshape((len(l), 1))
+        M = np.int32(P == L)
+        for i in range(M.shape[0]):
+            for j in range(M.shape[1]):
+                up = 0 if i == 0 else M[i-1, j]
+                left = 0 if j == 0 else M[i, j-1]
+                M[i, j] = max(up, left, M[i, j] if (i == 0 or j == 0) else M[i, j] + M[i-1, j-1])
+        return M.max()
+
+    def accuracy(self, label, pred):
+        """ Simple accuracy measure: number of 100% accurate predictions divided by total number """
+        hit = 0.
+        total = 0.
+        batch_size = label.shape[0]
+        for i in range(batch_size):
+            l = self._remove_blank(label[i])
+            p = []
+            for k in range(self.seq_len):
+                p.append(np.argmax(pred[k * batch_size + i]))
+            p = self.ctc_label(p)
+            if len(p) == len(l):
+                match = True
+                for k, _ in enumerate(p):
+                    if p[k] != int(l[k]):
+                        match = False
+                        break
+                if match:
+                    hit += 1.0
+            total += 1.0
+        assert total == batch_size
+        return hit / total
+
+    def accuracy_lcs(self, label, pred):
+        """ Longest Common Subsequence accuracy measure: calculate accuracy of each prediction as LCS/length"""
+        hit = 0.
+        total = 0.
+        batch_size = label.shape[0]
+        for i in range(batch_size):
+            l = self._remove_blank(label[i])
+            p = []
+            for k in range(self.seq_len):
+                p.append(np.argmax(pred[k * batch_size + i]))
+            p = self.ctc_label(p)
+            hit += self._lcs(p, l) * 1.0 / len(l)
+            total += 1.0
+        assert total == batch_size
+        return hit / total
+
diff --git a/example/ctc/hyperparams.py b/example/ctc/hyperparams.py
new file mode 100644
index 0000000000..7289d19c03
--- /dev/null
+++ b/example/ctc/hyperparams.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Hyperparameters for LSTM OCR Example """
+
+from __future__ import print_function
+
+
+class Hyperparams(object):
+    """
+    Hyperparameters for LSTM network
+    """
+    def __init__(self):
+        # Training hyper parameters
+        self._train_epoch_size = 30000
+        self._eval_epoch_size = 3000
+        self._batch_size = 128
+        self._num_epoch = 100
+        self._learning_rate = 0.001
+        self._momentum = 0.9
+        self._num_label = 4
+        # Network hyper parameters
+        self._seq_length = 80
+        self._num_hidden = 100
+        self._num_lstm_layer = 2
+
+    @property
+    def train_epoch_size(self):
+        return self._train_epoch_size
+
+    @property
+    def eval_epoch_size(self):
+        return self._eval_epoch_size
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+    
+    @property
+    def num_epoch(self):
+        return self._num_epoch
+
+    @property
+    def learning_rate(self):
+        return self._learning_rate
+
+    @property
+    def momentum(self):
+        return self._momentum
+
+    @property
+    def num_label(self):
+        return self._num_label
+
+    @property
+    def seq_length(self):
+        return self._seq_length
+
+    @property
+    def num_hidden(self):
+        return self._num_hidden
+
+    @property
+    def num_lstm_layer(self):
+        return self._num_lstm_layer
diff --git a/example/ctc/lstm.py b/example/ctc/lstm.py
index 326daa1d9f..dcf8b4e4ef 100644
--- a/example/ctc/lstm.py
+++ b/example/ctc/lstm.py
@@ -14,29 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""Contains helpers for creating an LSTM symbolic graph for training and inference"""
 
-# pylint:skip-file
-import sys
+from __future__ import print_function
 
-from mxnet.symbol_doc import SymbolDoc
+from collections import namedtuple
 
-sys.path.insert(0, "../../python")
 import mxnet as mx
-import numpy as np
-from collections import namedtuple
-import time
-import math
+
+
+__all__ = ["lstm_unroll", "init_states"]
+
 
 LSTMState = namedtuple("LSTMState", ["c", "h"])
 LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                      "h2h_weight", "h2h_bias"])
-LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
-                                     "init_states", "last_states",
-                                     "seq_data", "seq_labels", "seq_outputs",
-                                     "param_blocks"])
 
 
-def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
+def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
     """LSTM Cell symbol"""
     i2h = mx.sym.FullyConnected(data=indata,
                                 weight=param.i2h_weight,
@@ -60,8 +55,8 @@ def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
     return LSTMState(c=next_c, h=next_h)
 
 
-def lstm_unroll(num_lstm_layer, seq_len,
-                num_hidden, num_label):
+def _lstm_unroll_base(num_lstm_layer, seq_len, num_hidden):
+    """ Returns symbol for LSTM model up to loss/softmax"""
     param_cells = []
     last_states = []
     for i in range(num_lstm_layer):
@@ -72,35 +67,108 @@ def lstm_unroll(num_lstm_layer, seq_len,
         state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                           h=mx.sym.Variable("l%d_init_h" % i))
         last_states.append(state)
-    assert (len(last_states) == num_lstm_layer)
+    assert len(last_states) == num_lstm_layer
 
-    # embeding layer
+    # embedding layer
     data = mx.sym.Variable('data')
-    label = mx.sym.Variable('label')
     wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
 
     hidden_all = []
     for seqidx in range(seq_len):
         hidden = wordvec[seqidx]
         for i in range(num_lstm_layer):
-            next_state = lstm(num_hidden, indata=hidden,
-                              prev_state=last_states[i],
-                              param=param_cells[i],
-                              seqidx=seqidx, layeridx=i)
+            next_state = _lstm(
+                num_hidden=num_hidden,
+                indata=hidden,
+                prev_state=last_states[i],
+                param=param_cells[i],
+                seqidx=seqidx,
+                layeridx=i)
             hidden = next_state.h
             last_states[i] = next_state
         hidden_all.append(hidden)
 
     hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
+    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11, name="pred_fc")
+    return pred_fc
 
-    pred_fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
-    pred_ctc = mx.sym.Reshape(data=pred_fc, shape=(-4, seq_len, -1, 0))
+
+def _add_warp_ctc_loss(pred, seq_len, num_label, label):
+    """ Adds mxnet.symbol.WarpCTC loss on top of pred symbol and returns the resulting symbol """
+    label = mx.sym.Reshape(data=label, shape=(-1,))
+    label = mx.sym.Cast(data=label, dtype='int32')
+    return mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)
+
+
+def _add_mxnet_ctc_loss(pred, seq_len, label):
+    """ Adds mxnet.symbol.contrib.ctc_loss on top of pred symbol and returns the resulting symbol """
+    pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0))
 
     loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label)
     ctc_loss = mx.sym.MakeLoss(loss)
 
-    softmax_class = mx.symbol.SoftmaxActivation(data=pred_fc)
+    softmax_class = mx.symbol.SoftmaxActivation(data=pred)
     softmax_loss = mx.sym.MakeLoss(softmax_class)
     softmax_loss = mx.sym.BlockGrad(softmax_loss)
-
     return mx.sym.Group([softmax_loss, ctc_loss])
+
+
+def _add_ctc_loss(pred, seq_len, num_label, loss_type):
+    """ Adds CTC loss on top of pred symbol and returns the resulting symbol """
+    label = mx.sym.Variable('label')
+    if loss_type == 'warpctc':
+        print("Using WarpCTC Loss")
+        sm = _add_warp_ctc_loss(pred, seq_len, num_label, label)
+    else:
+        print("Using MXNet CTC Loss")
+        assert loss_type == 'ctc'
+        sm = _add_mxnet_ctc_loss(pred, seq_len, label)
+    return sm
+
+
+def lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_label, loss_type=None):
+    """
+    Creates an unrolled LSTM symbol for inference if loss_type is not specified, and for training
+    if loss_type is specified. loss_type must be one of 'ctc' or 'warpctc'
+
+    Parameters
+    ----------
+    num_lstm_layer: int
+    seq_len: int
+    num_hidden: int
+    num_label: int
+    loss_type: str
+        'ctc' or 'warpctc'
+
+    Returns
+    -------
+    mxnet.symbol.symbol.Symbol
+    """
+    # Create the base (shared between training and inference) and add loss to the end
+    pred = _lstm_unroll_base(num_lstm_layer, seq_len, num_hidden)
+
+    if loss_type:
+        # Training mode, add loss
+        return _add_ctc_loss(pred, seq_len, num_label, loss_type)
+    else:
+        # Inference mode, add softmax
+        return mx.sym.softmax(data=pred, name='softmax')
+
+
+def init_states(batch_size, num_lstm_layer, num_hidden):
+    """
+    Returns name and shape of init states of LSTM network
+
+    Parameters
+    ----------
+    batch_size: int
+    num_lstm_layer: int
+    num_hidden: int
+
+    Returns
+    -------
+    list of tuple of str and tuple of int and int
+    """
+    init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
+    init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
+    return init_c + init_h
diff --git a/example/ctc/lstm_ocr.py b/example/ctc/lstm_ocr.py
deleted file mode 100644
index c9928aa43a..0000000000
--- a/example/ctc/lstm_ocr.py
+++ /dev/null
@@ -1,254 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-from __future__ import print_function
-import sys, random
-sys.path.insert(0, "../../python")
-import numpy as np
-import mxnet as mx
-
-from lstm import lstm_unroll
-
-from captcha.image import ImageCaptcha
-import cv2, random
-
-
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-
-        self.pad = 0
-        self.index = None  # TODO: what is index?
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-
-def gen_rand():
-    buf = ""
-    max_len = random.randint(3, 4)
-    for i in range(max_len):
-        buf += str(random.randint(0, 9))
-    return buf
-
-
-def get_label(buf):
-    ret = np.zeros(4)
-    for i in range(len(buf)):
-        ret[i] = 1 + int(buf[i])
-    if len(buf) == 3:
-        ret[3] = 0
-    return ret
-
-
-class OCRIter(mx.io.DataIter):
-    def __init__(self, count, batch_size, num_label, init_states):
-        super(OCRIter, self).__init__()
-        global SEQ_LENGTH
-        # you can get this font from http://font.ubuntu.com/
-        self.captcha = ImageCaptcha(fonts=['./data/Xerox.ttf'])
-        self.batch_size = batch_size
-        self.count = count
-        self.num_label = num_label
-        self.init_states = init_states
-        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
-        self.provide_data = [('data', (batch_size, 80, 30))] + init_states
-        self.provide_label = [('label', (self.batch_size, 4))]
-        self.cache_data = []
-        self.cache_label = []
-
-    def __iter__(self):
-        print('iter')
-        init_state_names = [x[0] for x in self.init_states]
-        for k in range(self.count):
-            data = []
-            label = []
-            for i in range(self.batch_size):
-                num = gen_rand()
-                img = self.captcha.generate(num)
-                img = np.fromstring(img.getvalue(), dtype='uint8')
-                img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
-                img = cv2.resize(img, (80, 30))
-                img = img.transpose(1, 0)
-                img = img.reshape((80, 30))
-                img = np.multiply(img, 1 / 255.0)
-                data.append(img)
-                label.append(get_label(num))
-
-            data_all = [mx.nd.array(data)] + self.init_state_arrays
-            label_all = [mx.nd.array(label)]
-            data_names = ['data'] + init_state_names
-            label_names = ['label']
-
-            data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
-            yield data_batch
-
-    def reset(self):
-        self.cache_data.clear()
-        self.cache_label.clear()
-        pass
-
-
-BATCH_SIZE = 1024
-SEQ_LENGTH = 80
-
-
-def ctc_label(p):
-    ret = []
-    p1 = [0] + p
-    for i in range(len(p)):
-        c1 = p1[i]
-        c2 = p1[i + 1]
-        if c2 == 0 or c2 == c1:
-            continue
-        ret.append(c2)
-    return ret
-
-
-def remove_blank(l):
-    ret = []
-    for i in range(len(l)):
-        if l[i] == 0:
-            break
-        ret.append(l[i])
-    return ret
-
-
-def Accuracy(label, pred):
-    global BATCH_SIZE
-    global SEQ_LENGTH
-    hit = 0.
-    total = 0.
-    rp = np.argmax(pred, axis=1)
-    for i in range(BATCH_SIZE):
-        l = remove_blank(label[i])
-        p = []
-        for k in range(SEQ_LENGTH):
-            p.append(np.argmax(pred[k * BATCH_SIZE + i]))
-        p = ctc_label(p)
-        if len(p) == len(l):
-            match = True
-            for k in range(len(p)):
-                if p[k] != int(l[k]):
-                    match = False
-                    break
-            if match:
-                hit += 1.0
-        total += 1.0
-    return hit / total
-
-
-def LCS(p, l):
-    # Dynamic Programming Finding LCS
-    if len(p) == 0:
-        return 0
-    P = np.array(list(p)).reshape((1, len(p)))
-    L = np.array(list(l)).reshape((len(l), 1))
-    M = np.int32(P == L)
-    for i in range(M.shape[0]):
-        for j in range(M.shape[1]):
-            up = 0 if i == 0 else M[i - 1, j]
-            left = 0 if j == 0 else M[i, j - 1]
-            M[i, j] = max(up, left, M[i, j] if (i == 0 or j == 0) else M[i, j] + M[i - 1, j - 1])
-    return M.max()
-
-
-def Accuracy_LCS(label, pred):
-    global BATCH_SIZE
-    global SEQ_LENGTH
-    hit = 0.
-    total = 0.
-    for i in range(BATCH_SIZE):
-        l = remove_blank(label[i])
-        p = []
-        for k in range(SEQ_LENGTH):
-            p.append(np.argmax(pred[k * BATCH_SIZE + i]))
-        p = ctc_label(p)
-        hit += LCS(p, l) * 1.0 / len(l)
-        total += 1.0
-    return hit / total
-
-
-def asum_stat(x):
-    """returns |x|/size(x), async execution."""
-    # npx = x.asnumpy()
-    # print(npx)
-    return x
-    return mx.ndarray.norm(x) / np.sqrt(x.size)
-
-
-if __name__ == '__main__':
-    num_hidden = 100
-    num_lstm_layer = 2
-
-    num_epoch = 100
-    learning_rate = 0.01
-    momentum = 0.9
-    num_label = 4
-
-    contexts = [mx.context.gpu(0)]
-
-
-    def sym_gen(seq_len):
-        return lstm_unroll(num_lstm_layer, seq_len,
-                           num_hidden=num_hidden,
-                           num_label=num_label)
-
-
-    init_c = [('l%d_init_c' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
-    init_h = [('l%d_init_h' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
-    init_states = init_c + init_h
-
-    data_train = OCRIter(20000, BATCH_SIZE, num_label, init_states)
-    data_val = OCRIter(1000, BATCH_SIZE, num_label, init_states)
-
-    symbol = sym_gen(SEQ_LENGTH)
-
-    import logging
-
-    head = '%(asctime)-15s %(message)s'
-    logging.basicConfig(level=logging.DEBUG, format=head)
-
-    print('begin fit')
-
-    module = mx.mod.Module(symbol, data_names=['data', 'l0_init_c', 'l0_init_h', 'l1_init_c', 'l1_init_h'],
-                           label_names=['label'],
-                           context=contexts)
-
-    module.fit(train_data=data_train,
-               eval_data=data_val,
-               eval_metric=mx.metric.np(Accuracy, allow_extra_outputs=True),
-               optimizer='sgd',
-               optimizer_params={'learning_rate': learning_rate,
-                                 'momentum': momentum,
-                                 'wd': 0.00001,
-                                 },
-               initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
-               num_epoch=num_epoch,
-               batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),
-               epoch_end_callback=mx.callback.do_checkpoint("ocr"),
-               )
diff --git a/example/ctc/lstm_ocr_infer.py b/example/ctc/lstm_ocr_infer.py
new file mode 100644
index 0000000000..80de2c7efa
--- /dev/null
+++ b/example/ctc/lstm_ocr_infer.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" An example of predicting CAPTCHA image data with an LSTM network pre-trained with a CTC loss"""
+
+from __future__ import print_function
+
+import argparse
+
+from ctc_metrics import CtcMetrics
+import cv2
+from hyperparams import Hyperparams
+import lstm
+import mxnet as mx
+import numpy as np
+from ocr_iter import SimpleBatch
+
+
+def read_img(path):
+    """ Reads image specified by path into numpy.ndarray"""
+    img = cv2.resize(cv2.imread(path, 0), (80, 30)).astype(np.float32) / 255
+    img = np.expand_dims(img.transpose(1, 0), 0)
+    return img
+
+
+def lstm_init_states(batch_size):
+    """ Returns a tuple of names and zero arrays for LSTM init states"""
+    hp = Hyperparams()
+    init_shapes = lstm.init_states(batch_size=batch_size, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden)
+    init_names = [s[0] for s in init_shapes]
+    init_arrays = [mx.nd.zeros(x[1]) for x in init_shapes]
+    return init_names, init_arrays
+
+
+def load_module(prefix, epoch, data_names, data_shapes):
+    """
+    Loads the model from the checkpoint specified by prefix and epoch, binds it
+    to an executor, sets its parameters, and returns an mx.mod.Module
+    """
+    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+
+    # We don't need the CTC loss for prediction; a simple softmax will suffice.
+    # We get the output of the layer just before the loss layer ('pred_fc') and add softmax on top
+    pred_fc = sym.get_internals()['pred_fc_output']
+    sym = mx.sym.softmax(data=pred_fc)
+
+    mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=data_names, label_names=None)
+    mod.bind(for_training=False, data_shapes=data_shapes)
+    mod.set_params(arg_params, aux_params, allow_missing=False)
+    return mod
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", help="Path to the CAPTCHA image file")
+    parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr')
+    parser.add_argument("--epoch", help="Checkpoint epoch [Default 100]", type=int, default=100)
+    args = parser.parse_args()
+
+    init_state_names, init_state_arrays = lstm_init_states(batch_size=1)
+    img = read_img(args.path)
+
+    sample = SimpleBatch(
+        data_names=['data'] + init_state_names,
+        data=[mx.nd.array(img)] + init_state_arrays)
+
+    mod = load_module(args.prefix, args.epoch, sample.data_names, sample.provide_data)
+
+    mod.forward(sample)
+    prob = mod.get_outputs()[0].asnumpy()
+
+    prediction = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist())
+    # Predictions are 1 to 10 for digits 0 to 9 respectively (prediction 0 means no-digit)
+    prediction = [p - 1 for p in prediction]
+    print("Digits:", prediction)
+    return
+
+
+if __name__ == '__main__':
+    main()
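A note on the decoding step in lstm_ocr_infer.py above: CtcMetrics.ctc_label (defined in the accompanying ctc_metrics.py, not shown in this diff) collapses the per-timestep argmax sequence by dropping CTC blanks and repeated symbols, after which the script shifts classes 1-10 back to digits 0-9. A minimal sketch of that collapse rule, assuming it follows the standard CTC decoding convention (collapse_ctc is an illustrative name, not part of the example):

```python
# Illustration only: standard CTC collapse followed by the digit shift used above.
def collapse_ctc(sequence):
    """Drop blanks (class 0) and repeated consecutive symbols."""
    out, prev = [], 0
    for s in sequence:
        if s != 0 and s != prev:
            out.append(s)
        prev = s
    return out

raw = [0, 3, 3, 0, 0, 8, 8, 8, 0, 1, 0]      # per-timestep argmax, 0 = blank
digits = [p - 1 for p in collapse_ctc(raw)]  # classes 1..10 map to digits 0..9
print(digits)                                # [2, 7, 0]
```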
diff --git a/example/ctc/lstm_ocr_train.py b/example/ctc/lstm_ocr_train.py
new file mode 100644
index 0000000000..2c25f7e31e
--- /dev/null
+++ b/example/ctc/lstm_ocr_train.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" An example of using WarpCTC loss for an OCR problem using LSTM and CAPTCHA image data"""
+
+from __future__ import print_function
+
+import argparse
+import logging
+import os
+
+from captcha_generator import MPDigitCaptcha
+from hyperparams import Hyperparams
+from ctc_metrics import CtcMetrics
+import lstm
+import mxnet as mx
+from ocr_iter import OCRIter
+
+
+def get_fonts(path):
+    fonts = list()
+    if os.path.isdir(path):
+        for filename in os.listdir(path):
+            if filename.endswith('.ttf'):
+                fonts.append(os.path.join(path, filename))
+    else:
+        fonts.append(path)
+    return fonts
+
+
+def parse_args():
+    # Parse command line arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument("font_path", help="Path to ttf font file or directory containing ttf files")
+    parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc')
+    parser.add_argument("--cpu",
+                        help="Number of CPUs for training [Default 8]. Ignored if --gpu is specified.",
+                        type=int, default=8)
+    parser.add_argument("--gpu", help="Number of GPUs for training [Default 0]", type=int)
+    parser.add_argument("--num_proc", help="Number CAPTCHA generating processes [Default 4]", type=int, default=4)
+    parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr')
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    if not any(args.loss == s for s in ['ctc', 'warpctc']):
+        raise ValueError("Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss))
+
+    hp = Hyperparams()
+
+    # Start a multiprocessor captcha image generator
+    mp_captcha = MPDigitCaptcha(
+        font_paths=get_fonts(args.font_path), h=hp.seq_length, w=30,
+        num_digit_min=3, num_digit_max=4, num_processes=args.num_proc, max_queue_size=hp.batch_size * 2)
+    try:
+        # Must call start() before any call to mxnet module (https://github.com/apache/incubator-mxnet/issues/9213)
+        mp_captcha.start()
+
+        if args.gpu:
+            contexts = [mx.context.gpu(i) for i in range(args.gpu)]
+        else:
+            contexts = [mx.context.cpu(i) for i in range(args.cpu)]
+
+        init_states = lstm.init_states(hp.batch_size, hp.num_lstm_layer, hp.num_hidden)
+
+        data_train = OCRIter(
+            hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='train')
+        data_val = OCRIter(
+            hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='val')
+
+        symbol = lstm.lstm_unroll(
+            num_lstm_layer=hp.num_lstm_layer,
+            seq_len=hp.seq_length,
+            num_hidden=hp.num_hidden,
+            num_label=hp.num_label,
+            loss_type=args.loss)
+
+        head = '%(asctime)-15s %(message)s'
+        logging.basicConfig(level=logging.DEBUG, format=head)
+
+        module = mx.mod.Module(
+            symbol,
+            data_names=['data', 'l0_init_c', 'l0_init_h', 'l1_init_c', 'l1_init_h'],
+            label_names=['label'],
+            context=contexts)
+
+        metrics = CtcMetrics(hp.seq_length)
+        module.fit(train_data=data_train,
+                   eval_data=data_val,
+                   # use metrics.accuracy or metrics.accuracy_lcs
+                   eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
+                   optimizer='sgd',
+                   optimizer_params={'learning_rate': hp.learning_rate,
+                                     'momentum': hp.momentum,
+                                     'wd': 0.00001,
+                                     },
+                   initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
+                   num_epoch=hp.num_epoch,
+                   batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
+                   epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
+                   )
+    except KeyboardInterrupt:
+        print("W: interrupt received, stopping...")
+    finally:
+        # Reset multiprocessing captcha generator to stop processes
+        mp_captcha.reset()
+
+
+if __name__ == '__main__':
+    main()
+
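The try/finally structure above encodes an ordering constraint worth noting: the multiprocess captcha generator must be started before any MXNet call (see the linked issue #9213), and it must be reset even on interrupt so the worker processes exit. A stripped-down sketch of that pattern, where make_generator and train_fn are hypothetical stand-ins for building MPDigitCaptcha and running module.fit:

```python
# Sketch of the start-before-MXNet / always-reset pattern used by lstm_ocr_train.py.
def run_training(make_generator, train_fn):
    gen = make_generator()        # hypothetical factory for an MPDigitCaptcha-like object
    try:
        gen.start()               # spawn worker processes before touching MXNet
        train_fn(gen)             # build the module and call fit() only after start()
    except KeyboardInterrupt:
        print("W: interrupt received, stopping...")
    finally:
        gen.reset()               # stop workers and drain the queue in all cases
```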
diff --git a/example/ctc/multiproc_data.py b/example/ctc/multiproc_data.py
new file mode 100644
index 0000000000..c5f8da5635
--- /dev/null
+++ b/example/ctc/multiproc_data.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+from ctypes import c_bool
+import multiprocessing as mp
+try:
+    from queue import Full as QFullExcept
+    from queue import Empty as QEmptyExcept
+except ImportError:
+    from Queue import Full as QFullExcept
+    from Queue import Empty as QEmptyExcept
+
+import numpy as np
+
+
+class MPData(object):
+    """
+    Handles multi-process data generation.
+
+    Operation:
+        - call start() to start the data generation
+        - call get() (blocking) to read one sample
+        - call reset() to stop data generation
+    """
+    def __init__(self, num_processes, max_queue_size, fn):
+        """
+
+        Parameters
+        ----------
+        num_processes: int
+            Number of processes to spawn
+        max_queue_size: int
+            Maximum samples in the queue before processes wait
+        fn: function
+            function that generates samples, executed on separate processes.
+        """
+        self.queue = mp.Queue(maxsize=int(max_queue_size))
+        self.alive = mp.Value(c_bool, False, lock=False)
+        self.num_proc = num_processes
+        self.proc = list()
+        self.fn = fn
+
+    def start(self):
+        """
+        Starts the processes
+        Parameters
+        ----------
+        fn: function
+
+        """
+        """
+        Starts the processes
+        """
+        self._init_proc()
+
+    @staticmethod
+    def _proc_loop(proc_id, alive, queue, fn):
+        """
+        Process loop for generating data
+
+        Parameters
+        ----------
+        proc_id: int
+            Process id
+        alive: multiprocessing.Value
+            variable for signaling whether process should continue or not
+        queue: multiprocessing.Queue
+            queue for passing data back
+        fn: function
+            function object that returns a sample to be pushed into the queue
+        """
+        print("proc {} started".format(proc_id))
+        try:
+            while alive.value:
+                data = fn()
+                put_success = False
+                while alive.value and not put_success:
+                    try:
+                        queue.put(data, timeout=0.5)
+                        put_success = True
+                    except QFullExcept:
+                        # print("Queue Full")
+                        pass
+        except KeyboardInterrupt:
+            print("W: interrupt received, stopping process {} ...".format(proc_id))
+        print("Closing process {}".format(proc_id))
+        queue.close()
+
+    def _init_proc(self):
+        """
+        Start processes if not already started
+        """
+        if not self.proc:
+            self.proc = [
+                mp.Process(target=self._proc_loop, args=(i, self.alive, self.queue, self.fn))
+                for i in range(self.num_proc)
+            ]
+            self.alive.value = True
+            for p in self.proc:
+                p.start()
+
+    def get(self):
+        """
+        Get a datum from the queue
+
+        Returns
+        -------
+        object
+            The next sample produced by fn()
+        """
+        self._init_proc()
+        return self.queue.get()
+
+    def reset(self):
+        """
+        Resets the generator by stopping all processes
+        """
+        self.alive.value = False
+        qsize = 0
+        try:
+            while True:
+                self.queue.get(timeout=0.1)
+                qsize += 1
+        except QEmptyExcept:
+            pass
+        print("Queue size on reset: {}".format(qsize))
+        for p in self.proc:
+            p.join()
+        self.proc = list()
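Since MPData is self-contained, a short usage sketch may help; make_sample is a hypothetical stand-in for the captcha-generating function, and the snippet assumes it is run from example/ctc so that multiproc_data is importable:

```python
import numpy as np
from multiproc_data import MPData

def make_sample():
    # Hypothetical generator: a random 80x30 "image" normalized to [0, 1]
    return np.random.rand(80, 30).astype(np.float32)

if __name__ == '__main__':
    data = MPData(num_processes=2, max_queue_size=16, fn=make_sample)
    data.start()                              # spawn the worker processes
    batch = [data.get() for _ in range(4)]    # blocking reads from the shared queue
    print(len(batch), batch[0].shape)
    data.reset()                              # signal workers to stop and drain the queue
```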
diff --git a/example/ctc/ocr_iter.py b/example/ctc/ocr_iter.py
new file mode 100644
index 0000000000..1432e92a80
--- /dev/null
+++ b/example/ctc/ocr_iter.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Iterator for Captcha images used for LSTM-based OCR model"""
+
+from __future__ import print_function
+
+import numpy as np
+import mxnet as mx
+
+
+class SimpleBatch(object):
+    def __init__(self, data_names, data, label_names=list(), label=list()):
+        self._data = data
+        self._label = label
+        self._data_names = data_names
+        self._label_names = label_names
+
+        self.pad = 0
+        self.index = None  # TODO: what is index?
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def label(self):
+        return self._label
+
+    @property
+    def data_names(self):
+        return self._data_names
+
+    @property
+    def label_names(self):
+        return self._label_names
+
+    @property
+    def provide_data(self):
+        return [(n, x.shape) for n, x in zip(self._data_names, self._data)]
+
+    @property
+    def provide_label(self):
+        return [(n, x.shape) for n, x in zip(self._label_names, self._label)]
+
+
+def get_label(buf):
+    ret = np.zeros(4)
+    for i in range(len(buf)):
+        ret[i] = 1 + int(buf[i])
+    if len(buf) == 3:
+        ret[3] = 0
+    return ret
+
+
+class OCRIter(mx.io.DataIter):
+    """
+    Iterator class for generating captcha image data
+    """
+    def __init__(self, count, batch_size, lstm_init_states, captcha, name):
+        """
+        Parameters
+        ----------
+        count: int
+            Number of batches to produce for one epoch
+        batch_size: int
+        lstm_init_states: list of tuple(str, tuple)
+            A list of tuples with [0] name and [1] shape of each LSTM init state
+        captcha: MPCaptcha
+            Captcha image generator. Can be MPCaptcha or any other class providing a .shape and .get() interface
+        name: str
+        """
+        super(OCRIter, self).__init__()
+        self.batch_size = batch_size
+        self.count = count
+        self.init_states = lstm_init_states
+        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states]
+        data_shape = captcha.shape
+        self.provide_data = [('data', (batch_size, data_shape[0], data_shape[1]))] + lstm_init_states
+        self.provide_label = [('label', (self.batch_size, 4))]
+        self.mp_captcha = captcha
+        self.name = name
+
+    def __iter__(self):
+        init_state_names = [x[0] for x in self.init_states]
+        for k in range(self.count):
+            data = []
+            label = []
+            for i in range(self.batch_size):
+                img, num = self.mp_captcha.get()
+                data.append(img)
+                label.append(get_label(num))
+            data_all = [mx.nd.array(data)] + self.init_state_arrays
+            label_all = [mx.nd.array(label)]
+            data_names = ['data'] + init_state_names
+            label_names = ['label']
+
+            data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
+            yield data_batch
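One detail of the iterator above that is easy to miss: get_label() encodes each digit as digit + 1 so that class 0 stays reserved for the CTC blank, and a 3-digit captcha is padded with a trailing blank up to the fixed label length of 4. A self-contained restatement of that convention (encode_label is just an illustrative name):

```python
import numpy as np

def encode_label(digits, max_len=4):
    # Same convention as get_label() above: class = digit + 1, 0 is the CTC blank,
    # and shorter labels are padded with blanks up to max_len.
    label = np.zeros(max_len)
    for i, d in enumerate(digits):
        label[i] = 1 + int(d)
    return label

print(encode_label("604"))   # 7, 1, 5, 0  (trailing blank pads the 3-digit label)
print(encode_label("2918"))  # 3, 10, 2, 9
```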
diff --git a/example/ctc/ocr_predict.py b/example/ctc/ocr_predict.py
index 3096a664a2..2cf19678f4 100644
--- a/example/ctc/ocr_predict.py
+++ b/example/ctc/ocr_predict.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python2.7
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -16,24 +14,28 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+""" An example of predicting CAPTCHA image data with a LSTM network pre-trained with a CTC loss"""
 
-# coding=utf-8
 from __future__ import print_function
-import sys, os
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append("../../amalgamation/python/")
-sys.path.append("../../python/")
 
-from mxnet_predict import Predictor
-import mxnet as mx
+import argparse
 
-import numpy as np
+import sys
 import cv2
-import os
+import numpy as np
+import mxnet as mx
+from collections import namedtuple
+from ocr_iter import SimpleBatch
+from captcha_generator import DigitCaptcha
+from ctc_metrics import CtcMetrics
+import lstm
+from hyperparams import Hyperparams
+
 
 class lstm_ocr_model(object):
     # Keep zero index for blank (CTC requires it)
-    CONST_CHAR='0123456789'
+    CONST_CHAR = '0123456789'
+
     def __init__(self, path_of_json, path_of_params):
         super(lstm_ocr_model, self).__init__()
         self.path_of_json = path_of_json
@@ -52,32 +54,37 @@ def __init_ocr(self):
         init_states = init_c + init_h
 
         init_state_arrays = np.zeros((batch_size, num_hidden), dtype="float32")
-        self.init_state_dict={}
+        self.init_state_dict = {}
         for x in init_states:
             self.init_state_dict[x[0]] = init_state_arrays
 
-        all_shapes = [('data', (batch_size, 80 * 30))] + init_states + [('label', (batch_size, num_label))]
+        all_shapes = [('data', (batch_size, 80, 30))] + init_states + [('label', (batch_size, num_label))]
         all_shapes_dict = {}
         for _shape in all_shapes:
             all_shapes_dict[_shape[0]] = _shape[1]
-        self.predictor = Predictor(open(self.path_of_json).read(),
-                                    open(self.path_of_params).read(),
-                                    all_shapes_dict)
-
-    def forward_ocr(self, img):
-        img = cv2.resize(img, (80, 30))
-        img = img.transpose(1, 0)
-        img = img.reshape((80 * 30))
-        img = np.multiply(img, 1/255.0)
-        self.predictor.forward(data=img, **self.init_state_dict)
+        self.predictor = Predictor(open(self.path_of_json, 'rb').read(),
+                                   open(self.path_of_params, 'rb').read(),
+                                   all_shapes_dict)
+
+    def forward_ocr(self, img_):
+        img_ = cv2.resize(img_, (80, 30))
+        img_ = img_.transpose(1, 0)
+        img_ = img_.reshape((1, 80, 30))
+        img_ = np.multiply(img_, 1 / 255.0)
+        self.predictor.forward(data=img_, **self.init_state_dict)
         prob = self.predictor.get_output(0)
         label_list = []
         for p in prob:
             max_index = np.argsort(p)[::-1][0]
             label_list.append(max_index)
         return self.__get_string(label_list)
 
-    def __get_string(self, label_list):
+    @staticmethod
+    def __get_string(label_list):
         # Do CTC label rule
         # CTC cannot emit a repeated symbol on consecutive timesteps
         ret = []
@@ -98,9 +105,55 @@ def __get_string(self, label_list):
             s += c
         return s
 
+
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("predict_lib_path", help="Path to directory containing mxnet_predict.so")
+    args = parser.parse_args()
+
+    sys.path.append(args.predict_lib_path + "/python")
+    from mxnet_predict import Predictor
+
     _lstm_ocr_model = lstm_ocr_model('ocr-symbol.json', 'ocr-0010.params')
-    img = cv2.imread('sample.jpg', 0)
+    img = cv2.imread('sample0.png', 0)
     _str = _lstm_ocr_model.forward_ocr(img)
     print('Result: ', _str)
-
diff --git a/example/warpctc/sample.jpg b/example/ctc/sample.jpg
similarity index 100%
rename from example/warpctc/sample.jpg
rename to example/ctc/sample.jpg
diff --git a/example/warpctc/README.md b/example/warpctc/README.md
deleted file mode 100644
index 9ab56b336a..0000000000
--- a/example/warpctc/README.md
+++ /dev/null
@@ -1,108 +0,0 @@
-# Baidu Warp CTC with Mxnet
-
-Baidu-warpctc is a CTC implement by Baidu which support GPU. CTC can be used with LSTM to solve lable alignment problems in many areas such as OCR, speech recognition.
-
-## Install baidu warpctc
-
-```
-  cd ~/
-  git clone https://github.com/baidu-research/warp-ctc
-  cd warp-ctc
-  mkdir build
-  cd build
-  cmake ..
-  make
-  sudo make install
-```
-
-## Enable warpctc in mxnet
-
-```
-  comment out following lines in make/config.mk
-  WARPCTC_PATH = $(HOME)/warp-ctc
-  MXNET_PLUGINS += plugin/warpctc/warpctc.mk
-  
-  rebuild mxnet by
-  make clean && make -j4
-```
-
-## Run examples
-
-I implement two examples, one is just a toy example which can be used to prove ctc integration is right. The second is a OCR example with LSTM+CTC. You can run it by:
-
-```
-  cd examples/warpctc
-  python lstm_ocr.py
-```
-
-Notes:
-* Please modify ```contexts = [mx.context.gpu(0)]``` in this file according to your hardware.
-* Please review the code ```'./font/Ubuntu-M.ttf'```. Copy your font to here font/yourfont.ttf. To get a free font from [here](http://font.ubuntu.com/).
-* The checkpoint will be auto saved in each epoch. And then you can use this checkpoint to do a predict.
-
-The OCR example is constructed as follows:
-  
-1. I generate 80x30 image for 4 digits captcha by an python captcha library
-2. The 80x30 image is used as 80 input for lstm and every input is one column of image (a 30 dim vector)
-3. The output layer use CTC loss
-
-Following code show detail construction of the net:
-
-```
-  def lstm_unroll(num_lstm_layer, seq_len,
-                  num_hidden, num_label):
-    param_cells = []
-    last_states = []
-    for i in range(num_lstm_layer):
-        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
-                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
-                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
-                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
-        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
-                          h=mx.sym.Variable("l%d_init_h" % i))
-        last_states.append(state)
-    assert(len(last_states) == num_lstm_layer)
-    data = mx.sym.Variable('data')
-    label = mx.sym.Variable('label')
-    
-    #every column of image is an input, there are seq_len inputs
-    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
-    hidden_all = []
-    for seqidx in range(seq_len):
-        hidden = wordvec[seqidx]
-        for i in range(num_lstm_layer):
-            next_state = lstm(num_hidden, indata=hidden,
-                              prev_state=last_states[i],
-                              param=param_cells[i],
-                              seqidx=seqidx, layeridx=i)
-            hidden = next_state.h
-            last_states[i] = next_state
-        hidden_all.append(hidden)
-    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
-    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
-    
-    # here we do NOT need to transpose label as other lstm examples do
-    label = mx.sym.Reshape(data=label, target_shape=(0,))
-    #label should be int type, so use cast
-    label = mx.sym.Cast(data = label, dtype = 'int32')
-    sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len)
-    return sm
-```
-  
-## Support multi label length
-
-If you label length is smaller than or equal to b. You should provide labels with length b, and for those samples which label length is smaller than b, you should append 0 to label data to make it have length b.
-
-Here, 0 is reserved for blank label.
-
-## Do a predict
-
-Pelase run:
-
-```
-python ocr_predict.py
-```
-
-Notes:
-* Change the code following the name of your params and json file.
-* You have to do a ```make``` in amalgamation folder.(a libmxnet_predict.so will be created in lib folder.)
diff --git a/example/warpctc/infer_ocr.py b/example/warpctc/infer_ocr.py
deleted file mode 100644
index d469990ff9..0000000000
--- a/example/warpctc/infer_ocr.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# coding=utf-8
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-import sys
-
-sys.path.insert(0, "../../python")
-from __future__ import print_function
-import numpy as np
-import mxnet as mx
-
-from lstm_model import LSTMInferenceModel
-
-import cv2, random
-from captcha.image import ImageCaptcha
-
-BATCH_SIZE = 32
-SEQ_LENGTH = 80
-
-
-def ctc_label(p):
-    ret = []
-    p1 = [0] + p
-    for i in range(len(p)):
-        c1 = p1[i]
-        c2 = p1[i + 1]
-        if c2 == 0 or c2 == c1:
-            continue
-        ret.append(c2)
-    return ret
-
-
-def remove_blank(l):
-    ret = []
-    for i in range(len(l)):
-        if l[i] == 0:
-            break
-        ret.append(l[i])
-    return ret
-
-
-def gen_rand():
-    buf = ""
-    max_len = random.randint(3,4)
-    for i in range(max_len):
-        buf += str(random.randint(0,9))
-    return buf
-
-if __name__ == '__main__':
-    num_hidden = 100
-    num_lstm_layer = 2
-
-    num_epoch = 10
-    learning_rate = 0.001
-    momentum = 0.9
-    num_label = 4
-
-    n_channel = 1
-    contexts = [mx.context.gpu(0)]
-    _, arg_params, __ = mx.model.load_checkpoint('ocr', num_epoch)
-
-    num = gen_rand()
-    print('Generated number: ' + num)
-    # change the fonts accordingly
-    captcha = ImageCaptcha(fonts=['./data/OpenSans-Regular.ttf'])
-    img = captcha.generate(num)
-    img = np.fromstring(img.getvalue(), dtype='uint8')
-    img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
-    img = cv2.resize(img, (80, 30))
-
-    img = img.transpose(1, 0)
-
-    img = img.reshape((1, 80 * 30))
-    img = np.multiply(img, 1 / 255.0)
-
-    data_shape = [('data', (1, n_channel * 80 * 30))]
-    input_shapes = dict(data_shape)
-
-    model = LSTMInferenceModel(num_lstm_layer,
-                               SEQ_LENGTH,
-                               num_hidden=num_hidden,
-                               num_label=num_label,
-                               arg_params=arg_params,
-                               data_size = n_channel * 30 * 80,
-                               ctx=contexts[0])
-
-    prob = model.forward(mx.nd.array(img))
-
-    p = []
-    for k in range(SEQ_LENGTH):
-        p.append(np.argmax(prob[k]))
-
-    p = ctc_label(p)
-    print('Predicted label: ' + str(p))
-
-    pred = ''
-    for c in p:
-        pred += str((int(c) - 1))
-
-    print('Predicted number: ' + pred)
-
-
diff --git a/example/warpctc/lstm.py b/example/warpctc/lstm.py
deleted file mode 100644
index 9e0e05c901..0000000000
--- a/example/warpctc/lstm.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint:skip-file
-import sys
-sys.path.insert(0, "../../python")
-import mxnet as mx
-import numpy as np
-from collections import namedtuple
-import time
-import math
-LSTMState = namedtuple("LSTMState", ["c", "h"])
-LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
-                                     "h2h_weight", "h2h_bias"])
-LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
-                                     "init_states", "last_states",
-                                     "seq_data", "seq_labels", "seq_outputs",
-                                     "param_blocks"])
-
-def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
-    """LSTM Cell symbol"""
-    i2h = mx.sym.FullyConnected(data=indata,
-                                weight=param.i2h_weight,
-                                bias=param.i2h_bias,
-                                num_hidden=num_hidden * 4,
-                                name="t%d_l%d_i2h" % (seqidx, layeridx))
-    h2h = mx.sym.FullyConnected(data=prev_state.h,
-                                weight=param.h2h_weight,
-                                bias=param.h2h_bias,
-                                num_hidden=num_hidden * 4,
-                                name="t%d_l%d_h2h" % (seqidx, layeridx))
-    gates = i2h + h2h
-    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
-                                      name="t%d_l%d_slice" % (seqidx, layeridx))
-    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
-    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
-    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
-    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
-    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
-    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
-    return LSTMState(c=next_c, h=next_h)
-
-
-def lstm_unroll(num_lstm_layer, seq_len,
-                num_hidden, num_label):
-    param_cells = []
-    last_states = []
-    for i in range(num_lstm_layer):
-        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
-                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
-                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
-                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
-        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
-                          h=mx.sym.Variable("l%d_init_h" % i))
-        last_states.append(state)
-    assert(len(last_states) == num_lstm_layer)
-
-    # embeding layer
-    data = mx.sym.Variable('data')
-    label = mx.sym.Variable('label')
-    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
-
-    hidden_all = []
-    for seqidx in range(seq_len):
-        hidden = wordvec[seqidx]
-        for i in range(num_lstm_layer):
-            next_state = lstm(num_hidden, indata=hidden,
-                              prev_state=last_states[i],
-                              param=param_cells[i],
-                              seqidx=seqidx, layeridx=i)
-            hidden = next_state.h
-            last_states[i] = next_state
-        hidden_all.append(hidden)
-
-    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
-    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
-
-    label = mx.sym.Reshape(data=label, shape=(-1,))
-    label = mx.sym.Cast(data = label, dtype = 'int32')
-    sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len)
-    return sm
-
-
-def lstm_inference_symbol(num_lstm_layer, seq_len, num_hidden, num_label):
-    param_cells = []
-    last_states = []
-    for i in range(num_lstm_layer):
-        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
-                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
-                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
-                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
-        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
-                          h=mx.sym.Variable("l%d_init_h" % i))
-        last_states.append(state)
-    assert (len(last_states) == num_lstm_layer)
-
-    # embeding layer
-    data = mx.sym.Variable('data')
-    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
-
-    hidden_all = []
-    for seqidx in range(seq_len):
-        hidden = wordvec[seqidx]
-        for i in range(num_lstm_layer):
-            next_state = lstm(num_hidden, indata=hidden,
-                              prev_state=last_states[i],
-                              param=param_cells[i],
-                              seqidx=seqidx, layeridx=i)
-            hidden = next_state.h
-            last_states[i] = next_state
-        hidden_all.append(hidden)
-
-    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
-    fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
-    sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')
-
-    output = [sm]
-    for state in last_states:
-        output.append(state.c)
-        output.append(state.h)
-    return mx.sym.Group(output)
diff --git a/example/warpctc/lstm_model.py b/example/warpctc/lstm_model.py
deleted file mode 100644
index d359f1ae5a..0000000000
--- a/example/warpctc/lstm_model.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-import sys
-sys.path.insert(0, "../../python")
-import numpy as np
-import mxnet as mx
-
-from lstm import LSTMState, LSTMParam, lstm, lstm_inference_symbol
-
-
-class LSTMInferenceModel(object):
-    def __init__(self,
-                 num_lstm_layer,
-                 seq_len,
-                 num_hidden,
-                 num_label,
-                 arg_params,
-                 data_size,
-                 ctx=mx.cpu()):
-        self.sym = lstm_inference_symbol(num_lstm_layer,
-                                         seq_len,
-                                         num_hidden,
-                                         num_label)
-
-        batch_size = 1
-        init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
-        init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
-        data_shape = [("data", (batch_size, data_size))]
-        input_shapes = dict(init_c + init_h + data_shape)
-        self.executor = self.sym.simple_bind(ctx=ctx, **input_shapes)
-
-        for key in self.executor.arg_dict.keys():
-            if key in arg_params:
-                arg_params[key].copyto(self.executor.arg_dict[key])
-
-        state_name = []
-        for i in range(num_lstm_layer):
-            state_name.append("l%d_init_c" % i)
-            state_name.append("l%d_init_h" % i)
-
-        self.states_dict = dict(zip(state_name, self.executor.outputs[1:]))
-        self.input_arr = mx.nd.zeros(data_shape[0][1])
-
-    def forward(self, input_data, new_seq=False):
-        if new_seq == True:
-            for key in self.states_dict.keys():
-                self.executor.arg_dict[key][:] = 0.
-        input_data.copyto(self.executor.arg_dict["data"])
-        self.executor.forward()
-        for key in self.states_dict.keys():
-            self.states_dict[key].copyto(self.executor.arg_dict[key])
-        prob = self.executor.outputs[0].asnumpy()
-        return prob
diff --git a/example/warpctc/lstm_ocr.py b/example/warpctc/lstm_ocr.py
deleted file mode 100644
index 9dd39efb49..0000000000
--- a/example/warpctc/lstm_ocr.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-from __future__ import print_function
-import sys, random
-sys.path.insert(0, "../../python")
-import numpy as np
-import mxnet as mx
-
-from lstm import lstm_unroll
-
-from io import BytesIO
-from captcha.image import ImageCaptcha
-import cv2, random
-
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-
-        self.pad = 0
-        self.index = None # TODO: what is index?
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-def gen_rand():
-    buf = ""
-    max_len = random.randint(3,4)
-    for i in range(max_len):
-        buf += str(random.randint(0,9))
-    return buf
-
-def get_label(buf):
-    ret = np.zeros(4)
-    for i in range(len(buf)):
-        ret[i] = 1 + int(buf[i])
-    if len(buf) == 3:
-        ret[3] = 0
-    return ret
-
-class OCRIter(mx.io.DataIter):
-    def __init__(self, count, batch_size, num_label, init_states):
-        super(OCRIter, self).__init__()
-        # you can get this font from http://font.ubuntu.com/
-        self.captcha = ImageCaptcha(fonts=['./font/Ubuntu-M.ttf'])
-        self.batch_size = batch_size
-        self.count = count
-        self.num_label = num_label
-        self.init_states = init_states
-        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
-        self.provide_data = [('data', (batch_size, 80, 30))] + init_states
-        self.provide_label = [('label', (self.batch_size, 4))]
-
-    def __iter__(self):
-        print('iter')
-        init_state_names = [x[0] for x in self.init_states]
-        for k in range(self.count):
-            data = []
-            label = []
-            for i in range(self.batch_size):
-                num = gen_rand()
-                img = self.captcha.generate(num)
-                img = np.fromstring(img.getvalue(), dtype='uint8')
-                img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
-                img = cv2.resize(img, (80, 30))
-                img = img.transpose(1, 0)
-                img = img.reshape((80, 30))
-                img = np.multiply(img, 1/255.0)
-                data.append(img)
-                label.append(get_label(num))
-
-            data_all = [mx.nd.array(data)] + self.init_state_arrays
-            label_all = [mx.nd.array(label)]
-            data_names = ['data'] + init_state_names
-            label_names = ['label']
-
-
-            data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
-            yield data_batch
-
-    def reset(self):
-        pass
-
-BATCH_SIZE = 32
-SEQ_LENGTH = 80
-
-def ctc_label(p):
-    ret = []
-    p1 = [0] + p
-    for i in range(len(p)):
-        c1 = p1[i]
-        c2 = p1[i+1]
-        if c2 == 0 or c2 == c1:
-            continue
-        ret.append(c2)
-    return ret
-
-def remove_blank(l):
-    ret = []
-    for i in range(len(l)):
-        if l[i] == 0:
-            break
-        ret.append(l[i])
-    return ret
-
-def Accuracy(label, pred):
-    global BATCH_SIZE
-    global SEQ_LENGTH
-    hit = 0.
-    total = 0.
-    for i in range(BATCH_SIZE):
-        l = remove_blank(label[i])
-        p = []
-        for k in range(SEQ_LENGTH):
-            p.append(np.argmax(pred[k * BATCH_SIZE + i]))
-        p = ctc_label(p)
-        if len(p) == len(l):
-            match = True
-            for k in range(len(p)):
-                if p[k] != int(l[k]):
-                    match = False
-                    break
-            if match:
-                hit += 1.0
-        total += 1.0
-    return hit / total
-
-def LCS(p,l):
-    # Dynamic Programming Finding LCS
-    if len(p) == 0:
-        return 0
-    P = np.array(list(p)).reshape((1, len(p)))
-    L = np.array(list(l)).reshape((len(l), 1))
-    M = np.int32(P == L)
-    for i in range(M.shape[0]):
-        for j in range(M.shape[1]):
-            up = 0 if i == 0 else M[i-1,j]
-            left = 0 if j == 0 else M[i,j-1]
-            M[i,j] = max(up, left, M[i,j] if (i == 0 or j == 0) else M[i,j] + M[i-1,j-1])
-    return M.max()
-
-
-def Accuracy_LCS(label, pred):
-    global BATCH_SIZE
-    global SEQ_LENGTH
-    hit = 0.
-    total = 0.
-    for i in range(BATCH_SIZE):
-        l = remove_blank(label[i])
-        p = []
-        for k in range(SEQ_LENGTH):
-            p.append(np.argmax(pred[k * BATCH_SIZE + i]))
-        p = ctc_label(p)
-        hit += LCS(p,l) * 1.0 / len(l)
-        total += 1.0
-    return hit / total
-
-if __name__ == '__main__':
-    num_hidden = 100
-    num_lstm_layer = 2
-
-    num_epoch = 10
-    learning_rate = 0.001
-    momentum = 0.9
-    num_label = 4
-
-    contexts = [mx.context.gpu(0)]
-
-    def sym_gen(seq_len):
-        return lstm_unroll(num_lstm_layer, seq_len,
-                           num_hidden=num_hidden,
-                           num_label = num_label)
-
-    init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
-    init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
-    init_states = init_c + init_h
-
-    data_train = OCRIter(10000, BATCH_SIZE, num_label, init_states)
-    data_val = OCRIter(1000, BATCH_SIZE, num_label, init_states)
-
-    symbol = sym_gen(SEQ_LENGTH)
-
-    model = mx.model.FeedForward(ctx=contexts,
-                                 symbol=symbol,
-                                 num_epoch=num_epoch,
-                                 learning_rate=learning_rate,
-                                 momentum=momentum,
-                                 wd=0.00001,
-                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
-
-    import logging
-    head = '%(asctime)-15s %(message)s'
-    logging.basicConfig(level=logging.DEBUG, format=head)
-
-    print('begin fit')
-
-    prefix = 'ocr'
-    model.fit(X=data_train, eval_data=data_val,
-              eval_metric = mx.metric.np(Accuracy),
-              # Use the following eval_metric if your num_label >= 10, or varies in a wide range
-              # eval_metric = mx.metric.np(Accuracy_LCS),
-              batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),
-              epoch_end_callback = mx.callback.do_checkpoint(prefix, 1))
-
-    model.save(prefix)
diff --git a/example/warpctc/ocr_predict.py b/example/warpctc/ocr_predict.py
deleted file mode 100644
index 3096a664a2..0000000000
--- a/example/warpctc/ocr_predict.py
+++ /dev/null
@@ -1,106 +0,0 @@
-#!/usr/bin/env python2.7
-
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# coding=utf-8
-from __future__ import print_function
-import sys, os
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.append("../../amalgamation/python/")
-sys.path.append("../../python/")
-
-from mxnet_predict import Predictor
-import mxnet as mx
-
-import numpy as np
-import cv2
-import os
-
-class lstm_ocr_model(object):
-    # Keep Zero index for blank. (CTC request it)
-    CONST_CHAR='0123456789'
-    def __init__(self, path_of_json, path_of_params):
-        super(lstm_ocr_model, self).__init__()
-        self.path_of_json = path_of_json
-        self.path_of_params = path_of_params
-        self.predictor = None
-        self.__init_ocr()
-
-    def __init_ocr(self):
-        num_label = 4 # Set your max length of label, add one more for blank
-        batch_size = 1
-
-        num_hidden = 100
-        num_lstm_layer = 2
-        init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
-        init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
-        init_states = init_c + init_h
-
-        init_state_arrays = np.zeros((batch_size, num_hidden), dtype="float32")
-        self.init_state_dict={}
-        for x in init_states:
-            self.init_state_dict[x[0]] = init_state_arrays
-
-        all_shapes = [('data', (batch_size, 80 * 30))] + init_states + [('label', (batch_size, num_label))]
-        all_shapes_dict = {}
-        for _shape in all_shapes:
-            all_shapes_dict[_shape[0]] = _shape[1]
-        self.predictor = Predictor(open(self.path_of_json).read(),
-                                    open(self.path_of_params).read(),
-                                    all_shapes_dict)
-
-    def forward_ocr(self, img):
-        img = cv2.resize(img, (80, 30))
-        img = img.transpose(1, 0)
-        img = img.reshape((80 * 30))
-        img = np.multiply(img, 1/255.0)
-        self.predictor.forward(data=img, **self.init_state_dict)
-        prob = self.predictor.get_output(0)
-        label_list = []
-        for p in prob:
-            max_index = np.argsort(p)[::-1][0]
-            label_list.append(max_index)
-        return self.__get_string(label_list)
-
-    def __get_string(self, label_list):
-        # Do CTC label rule
-        # CTC cannot emit a repeated symbol on consecutive timesteps
-        ret = []
-        label_list2 = [0] + list(label_list)
-        for i in range(len(label_list)):
-            c1 = label_list2[i]
-            c2 = label_list2[i+1]
-            if c2 == 0 or c2 == c1:
-                continue
-            ret.append(c2)
-        # change to ascii
-        s = ''
-        for l in ret:
-            if l > 0 and l < (len(lstm_ocr_model.CONST_CHAR)+1):
-                c = lstm_ocr_model.CONST_CHAR[l-1]
-            else:
-                c = ''
-            s += c
-        return s
-
-if __name__ == '__main__':
-    _lstm_ocr_model = lstm_ocr_model('ocr-symbol.json', 'ocr-0010.params')
-    img = cv2.imread('sample.jpg', 0)
-    _str = _lstm_ocr_model.forward_ocr(img)
-    print('Result: ', _str)
-
diff --git a/example/warpctc/toy_ctc.py b/example/warpctc/toy_ctc.py
deleted file mode 100644
index c7b0ccc3df..0000000000
--- a/example/warpctc/toy_ctc.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-from __future__ import print_function
-import sys
-sys.path.insert(0, "../../python")
-import numpy as np
-import mxnet as mx
-import random
-from lstm import lstm_unroll
-
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-
-        self.pad = 0
-        self.index = None # TODO: what is index?
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-def gen_feature(n):
-    ret = np.zeros(10)
-    ret[n] = 1
-    return ret
-
-def gen_rand():
-    num = random.randint(0, 9999)
-    buf = str(num)
-    while len(buf) < 4:
-        buf = "0" + buf
-    ret = []
-    for i in range(80):
-        c = int(buf[i // 20])
-        ret.append(gen_feature(c))
-    return buf, ret
-
-def get_label(buf):
-    ret = np.zeros(4)
-    for i in range(4):
-        ret[i] = 1 + int(buf[i])
-    return ret
-
-class DataIter(mx.io.DataIter):
-    def __init__(self, count, batch_size, num_label, init_states):
-        super(DataIter, self).__init__()
-        self.batch_size = batch_size
-        self.count = count
-        self.num_label = num_label
-        self.init_states = init_states
-        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
-        self.provide_data = [('data', (batch_size, 80, 10))] + init_states
-        self.provide_label = [('label', (self.batch_size, 4))]
-
-    def __iter__(self):
-        init_state_names = [x[0] for x in self.init_states]
-        for k in range(self.count):
-            data = []
-            label = []
-            for i in range(self.batch_size):
-                num, img = gen_rand()
-                data.append(img)
-                label.append(get_label(num))
-
-            data_all = [mx.nd.array(data)] + self.init_state_arrays
-            label_all = [mx.nd.array(label)]
-            data_names = ['data'] + init_state_names
-            label_names = ['label']
-
-
-            data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
-            yield data_batch
-
-    def reset(self):
-        pass
-
-BATCH_SIZE = 32
-SEQ_LENGTH = 80
-
-def ctc_label(p):
-    ret = []
-    p1 = [0] + p
-    for i in range(len(p)):
-        c1 = p1[i]
-        c2 = p1[i+1]
-        if c2 == 0 or c2 == c1:
-            continue
-        ret.append(c2)
-    return ret
-
-
-def Accuracy(label, pred):
-    global BATCH_SIZE
-    global SEQ_LENGTH
-    hit = 0.
-    total = 0.
-    for i in range(BATCH_SIZE):
-        l = label[i]
-        p = []
-        for k in range(SEQ_LENGTH):
-            p.append(np.argmax(pred[k * BATCH_SIZE + i]))
-        p = ctc_label(p)
-        if len(p) == len(l):
-            match = True
-            for k in range(len(p)):
-                if p[k] != int(l[k]):
-                    match = False
-                    break
-            if match:
-                hit += 1.0
-        total += 1.0
-    return hit / total
-
-if __name__ == '__main__':
-    num_hidden = 100
-    num_lstm_layer = 1
-
-    num_epoch = 10
-    learning_rate = 0.001
-    momentum = 0.9
-    num_label = 4
-
-    contexts = [mx.context.gpu(0)]
-
-    def sym_gen(seq_len):
-        return lstm_unroll(num_lstm_layer, seq_len,
-                           num_hidden=num_hidden,
-                           num_label = num_label)
-
-    init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
-    init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
-    init_states = init_c + init_h
-
-    data_train = DataIter(100000, BATCH_SIZE, num_label, init_states)
-    data_val = DataIter(1000, BATCH_SIZE, num_label, init_states)
-
-    symbol = sym_gen(SEQ_LENGTH)
-
-    model = mx.model.FeedForward(ctx=contexts,
-                                 symbol=symbol,
-                                 num_epoch=num_epoch,
-                                 learning_rate=learning_rate,
-                                 momentum=momentum,
-                                 wd=0.00001,
-                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
-
-    import logging
-    head = '%(asctime)-15s %(message)s'
-    logging.basicConfig(level=logging.DEBUG, format=head)
-
-    print('begin fit')
-
-    model.fit(X=data_train, eval_data=data_val,
-              eval_metric = mx.metric.np(Accuracy),
-              batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),)
-
-    model.save("ocr")


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services