Posted to commits@mxnet.apache.org by aa...@apache.org on 2019/02/13 18:23:41 UTC

[incubator-mxnet] branch master updated: Update lip reading example (#13647)

This is an automated email from the ASF dual-hosted git repository.

aaronmarkham pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ff6ad1  Update lip reading example (#13647)
7ff6ad1 is described below

commit 7ff6ad1586b21256a27aaebb34b6ac84e6db0e0b
Author: seujung hwan, Jung <di...@gmail.com>
AuthorDate: Thu Feb 14 03:23:25 2019 +0900

    Update lip reading example (#13647)
    
    * update lipnet
    
    * update utils
    
    * Update example/gluon/lipnet/README.md
    
    Co-Authored-By: seujung <di...@gmail.com>
    
    * Update example/gluon/lipnet/README.md
    
    Co-Authored-By: seujung <di...@gmail.com>
    
    * Update example/gluon/lipnet/utils/multi.py
    
    Co-Authored-By: seujung <di...@gmail.com>
    
    * Update example/gluon/lipnet/utils/preprocess_data.py
    
    Co-Authored-By: seujung <di...@gmail.com>
    
    * Update example/gluon/lipnet/utils/multi.py
    
    Co-Authored-By: seujung <di...@gmail.com>
    
    * Update example/gluon/lipnet/utils/download_data.py
    
    Co-Authored-By: seujung <di...@gmail.com>
    
    * fix error for using gpu mode
    
    * Add requirements
    
    * Remove unnecessary requirements
    
    * Update .gitignore
    
    * Remove inappropriate license file
    
    * Changed relative path
    
    * Fix description
    
    * Fix description
    
    * Fix description
    
    * Fix description
    
    * Change doc strings and add url reference
    
    * Fix align_path
    
    * Remove zip files
    
    * Fix bugs: source_path, n_process
    
    * Fix target_path
    
    * Fix exception handler and resume the preprocess
    
    * Pass the output when it fails to detect the mouth
    
    * Add exception during collecting images
    
    * Add the disk space and fix default align_path
    
    * Change optimizer
    
    * Update readme for pip
    
    * Update README
    
    * Add checkpoint folder
    
    * Apply to train using multiprocess
    
    * update network.py
    
    * delete batchnorm comment
    * fix dropout
    * fix loading ndarray as F
    * add space
    
    * Update readme
    
    * Add the info of GRID Data
    * Add the info of word alignments
    * Add total download size
    * Add time for preprocessing
    
    * Add test code for beamsearch
    
    * add space
    
    * delete line and fix code
    
    * Add shebang in BeamSearch
    
    * Fix trainer
    
    * Add space line
    
    * Fix appending losses
    
    * Fix trainer
    
    * Delete debug line in data_loader
    
    * Move transpose of input into data_loader
    
    * Delete trailing-whitespace
    
    * Hybridize lip model
    
    * Hybridize model
    
    * Refactor the len of input sequence
    
    * Fix the shape of model
    
    * Apply to split train and validation
    
    * Split data into train and valid
    
    * Update Readme
    
    * Add infer.py
    
    * Remove ipynb
    
    * Apply to continual learning
    
    * Add images
    
    * Update readme
    
    * Fix typo and pylint
    
    * Fix loss digits of save_file and typo
    
    * Add info of data split and batch size
---
 example/gluon/lipnet/.gitignore                    |   3 +
 example/gluon/lipnet/BeamSearch.py                 | 170 ++++++++++
 example/gluon/lipnet/README.md                     | 194 +++++++++++
 example/gluon/lipnet/asset/mouth_000.png           | Bin 0 -> 6372 bytes
 example/gluon/lipnet/asset/mouth_001.png           | Bin 0 -> 6826 bytes
 example/gluon/lipnet/asset/mouth_074.png           | Bin 0 -> 6864 bytes
 example/gluon/lipnet/asset/network_structure.png   | Bin 0 -> 183728 bytes
 example/gluon/lipnet/asset/s2_bbbf7p_000.png       | Bin 0 -> 35141 bytes
 example/gluon/lipnet/asset/s2_bbbf7p_001.png       | Bin 0 -> 36768 bytes
 example/gluon/lipnet/asset/s2_bbbf7p_074.png       | Bin 0 -> 38248 bytes
 example/gluon/lipnet/checkpoint/__init__.py        |  16 +
 example/gluon/lipnet/data_loader.py                |  94 ++++++
 example/gluon/lipnet/infer.py                      |  52 +++
 example/gluon/lipnet/main.py                       |  47 +++
 example/gluon/lipnet/models/__init__.py            |   0
 example/gluon/lipnet/models/network.py             |  73 +++++
 example/gluon/lipnet/requirements.txt              |   7 +
 example/gluon/lipnet/tests/test_beamsearch.py      |  42 +++
 example/gluon/lipnet/trainer.py                    | 232 +++++++++++++
 example/gluon/lipnet/utils/__init__.py             |  16 +
 example/gluon/lipnet/utils/align.py                |  83 +++++
 example/gluon/lipnet/utils/common.py               |  80 +++++
 example/gluon/lipnet/utils/download_data.py        | 112 +++++++
 example/gluon/lipnet/utils/multi.py                | 104 ++++++
 example/gluon/lipnet/utils/preprocess_data.py      | 262 +++++++++++++++
 example/gluon/lipnet/utils/run_preprocess.ipynb    | 194 +++++++++++
 .../utils/run_preprocess_single_process.ipynb      | 360 +++++++++++++++++++++
 27 files changed, 2141 insertions(+)

diff --git a/example/gluon/lipnet/.gitignore b/example/gluon/lipnet/.gitignore
new file mode 100644
index 0000000..9a6ee99
--- /dev/null
+++ b/example/gluon/lipnet/.gitignore
@@ -0,0 +1,3 @@
+__pycache__/
+utils/*.dat
+
diff --git a/example/gluon/lipnet/BeamSearch.py b/example/gluon/lipnet/BeamSearch.py
new file mode 100644
index 0000000..1b41bc0
--- /dev/null
+++ b/example/gluon/lipnet/BeamSearch.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Module : this module to decode using beam search
+https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py 
+"""
+
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+
+class BeamEntry:
+    """
+    information about one single beam at specific time-step
+    """
+    def __init__(self):
+        self.prTotal = 0 # blank and non-blank
+        self.prNonBlank = 0 # non-blank
+        self.prBlank = 0 # blank
+        self.prText = 1 # LM score
+        self.lmApplied = False # flag if LM was already applied to this beam
+        self.labeling = () # beam-labeling
+
+class BeamState:
+    """
+    information about the beams at specific time-step
+    """
+    def __init__(self):
+        self.entries = {}
+
+    def norm(self):
+        """
+        length-normalise LM score
+        """
+        for (k, _) in self.entries.items():
+            labelingLen = len(self.entries[k].labeling)
+            self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0))
+
+    def sort(self):
+        """
+        return beam-labelings, sorted by probability
+        """
+        beams = [v for (_, v) in self.entries.items()]
+        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText)
+        return [x.labeling for x in sortedBeams]
+
+def applyLM(parentBeam, childBeam, classes, lm):
+    """
+    calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars
+    """
+    if lm and not childBeam.lmApplied:
+        c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')] # first char
+        c2 = classes[childBeam.labeling[-1]] # second char
+        lmFactor = 0.01 # influence of language model
+        bigramProb = lm.getCharBigram(c1, c2) ** lmFactor # probability of seeing first and second char next to each other
+        childBeam.prText = parentBeam.prText * bigramProb # probability of char sequence
+        childBeam.lmApplied = True # only apply LM once per beam entry
+
+def addBeam(beamState, labeling):
+    """
+    add beam if it does not yet exist
+    """
+    if labeling not in beamState.entries:
+        beamState.entries[labeling] = BeamEntry()
+
+def ctcBeamSearch(mat, classes, lm, k, beamWidth):
+    """
+    beam search as described by the paper of Hwang et al. and the paper of Graves et al.
+    mat has shape (timesteps, len(classes) + 1), with the blank probability in
+    the last column; returns the k most probable label strings
+    """
+
+    blankIdx = len(classes)
+    maxT, maxC = mat.shape
+
+    # initialise beam state
+    last = BeamState()
+    labeling = ()
+    last.entries[labeling] = BeamEntry()
+    last.entries[labeling].prBlank = 1
+    last.entries[labeling].prTotal = 1
+
+    # go over all time-steps
+    for t in range(maxT):
+        curr = BeamState()
+
+        # get beam-labelings of best beams
+        bestLabelings = last.sort()[0:beamWidth]
+
+        # go over best beams
+        for labeling in bestLabelings:
+
+            # probability of paths ending with a non-blank
+            prNonBlank = 0
+            # in case of non-empty beam
+            if labeling:
+                # probability of paths with repeated last char at the end
+                try:
+                    prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]
+                except FloatingPointError:
+                    prNonBlank = 0
+
+            # probability of paths ending with a blank
+            prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx]
+
+            # add beam at current time-step if needed
+            addBeam(curr, labeling)
+
+            # fill in data
+            curr.entries[labeling].labeling = labeling
+            curr.entries[labeling].prNonBlank += prNonBlank
+            curr.entries[labeling].prBlank += prBlank
+            curr.entries[labeling].prTotal += prBlank + prNonBlank
+            curr.entries[labeling].prText = last.entries[labeling].prText # beam-labeling unchanged, so the LM score carries over
+            curr.entries[labeling].lmApplied = True # LM already applied at previous time-step for this beam-labeling
+
+            # extend current beam-labeling
+            for c in range(maxC - 1):
+                # add new char to current beam-labeling
+                newLabeling = labeling + (c,)
+
+                # if new labeling contains duplicate char at the end, only consider paths ending with a blank
+                if labeling and labeling[-1] == c:
+                    prNonBlank = mat[t, c] * last.entries[labeling].prBlank
+                else:
+                    prNonBlank = mat[t, c] * last.entries[labeling].prTotal
+
+                # add beam at current time-step if needed
+                addBeam(curr, newLabeling)
+
+                # fill in data
+                curr.entries[newLabeling].labeling = newLabeling
+                curr.entries[newLabeling].prNonBlank += prNonBlank
+                curr.entries[newLabeling].prTotal += prNonBlank
+
+                # apply LM
+                applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm)
+
+        # set new beam state
+        last = curr
+
+    # normalise LM scores according to beam-labeling-length
+    last.norm()
+
+    # sort by probability
+    bestLabelings = last.sort()[:k] # get most probable labeling
+
+    output = []
+    for bestLabeling in bestLabelings:
+        # map labels to chars
+        res = ''
+        for l in bestLabeling:
+            res += classes[l]
+        output.append(res)
+    return output
\ No newline at end of file
diff --git a/example/gluon/lipnet/README.md b/example/gluon/lipnet/README.md
new file mode 100644
index 0000000..70eda17
--- /dev/null
+++ b/example/gluon/lipnet/README.md
@@ -0,0 +1,194 @@
+# LipNet: End-to-End Sentence-level Lipreading
+
+---
+
+Gluon implementation of [LipNet: End-to-End Sentence-level Lipreading](https://arxiv.org/abs/1611.01599)
+
+![net_structure](asset/network_structure.png)
+
+## Requirements
+- Python 3.6.4
+- MXNet 1.3.0
+- Required disk space: 35 GB
+```
+pip install -r requirements.txt
+```
+
+---
+
+## The Data
+- The GRID audiovisual sentence corpus (http://spandh.dcs.shef.ac.uk/gridcorpus/)
+  - GRID is a large multitalker audiovisual sentence corpus to support joint computational-behavioral studies in speech perception. In brief, the corpus consists of high-quality audio and video (facial) recordings of 1000 sentences spoken by each of 34 talkers (18 male, 16 female). Sentences are of the form "put red at G9 now". The corpus, together with transcriptions, is freely available for research use.
+- Video: (normal) (480 MB each)
+  - Each movie contains one sentence consisting of 6 words.
+- Align: word alignments (190 KB each)
+  - Each align file holds 6 words with their start and end times, but this tutorial only needs the full sentence because it trains with CTC loss.
+ 
+---
+
+## Prepare the Data
+### (1) Download the data
+- Outputs
+  - Total movies (mp4): 16 GB
+  - Total aligns (text): 134 MB
+- Arguments
+  - src_path : Path for videos (default='./data/mp4s/')
+  - align_path : Path for aligns (default='./data/')
+  - n_process : Number of processes (default=1)
+
+```
+cd ./utils && python download_data.py --n_process=$(nproc)
+```
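+
+For example, to spell out the paths and use four worker processes (a sample invocation run from the utils directory, following the script's own relative-path defaults):
+
+```
+cd ./utils && python download_data.py --src_path=../data/mp4s --align_path=../data --n_process=4
+```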
+
+### (2) Preprocess the Data: Extract the mouth images from each video and save them.
+
+* Uses dlib's face landmark detection (http://dlib.net/)
+
+#### Preprocess (preprocess_data.py)
+* If the landmark predictor file is missing, it is downloaded automatically.
+* Using face landmark detection, it extracts the mouth region from each video.
+
+- Example:
+  - video: ./data/mp4s/s2/bbbf7p.mpg
+  - align (target): ./data/align/s2/bbbf7p.align
+    : 'sil bin blue by f seven please sil'
+
+
+- Video to the images (75 Frames)
+
+Frame 0            |  Frame 1 | ... | Frame 74 |
+:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------:
+![](asset/s2_bbbf7p_000.png)  |  ![](asset/s2_bbbf7p_001.png) |  ...  |  ![](asset/s2_bbbf7p_074.png)
+
+- Extract the mouth region from the images
+
+Frame 0            |  Frame 1 | ... | Frame 74 |
+:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------:
+![](asset/mouth_000.png)  |  ![](asset/mouth_001.png) |  ...  |  ![](asset/mouth_074.png)
+
+* Save the resulting images into tgt_path.
+
+----
+
+### How to run
+
+- Arguments
+  - src_path : Path for videos (default='./data/mp4s/')
+  - tgt_path : Path for preprocessed images (default='./data/datasets/')
+  - n_process : Number of processes (default=1)
+
+- Outputs
+  - Total images (png): 19 GB
+- Elapsed time
+  - About 54 hours using 1 process
+  - Using multiple processes cuts this time roughly in proportion to the number of processes.
+    - e.g. about 9 hours using 6 processes
+
+You can run the preprocessing with just one process, but this will take a long time (>48 hours). To use all of the available processors, use the following command:
+
+```
+cd ./utils && python preprocess_data.py --n_process=$(nproc)
+```
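+
+Equivalently, with the paths spelled out (a sample invocation, assuming the same relative-path convention as download_data.py when run from utils/):
+
+```
+cd ./utils && python preprocess_data.py --src_path=../data/mp4s --tgt_path=../data/datasets --n_process=6
+```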
+
+## Output: Data Structure
+
+```
+The training data folder should look like : 
+<train_data_root>
+                |--datasets
+                        |--s1
+                           |--bbir7s
+                               |--mouth_000.png
+                               |--mouth_001.png
+                                   ...
+                           |--bgaa8p
+                               |--mouth_000.png
+                               |--mouth_001.png
+                                  ...
+                        |--s2
+                            ...
+                 |--align
+                         |--bw1d8a.align
+                         |--bggzzs.align
+                             ...
+
+```
+
+---
+
+## Training
+
+- According to [LipNet: End-to-End Sentence-level Lipreading](https://arxiv.org/abs/1611.01599), four of the 34 subjects (S1, S2, S20, S22) are used for evaluation; the other subjects are used for training.
+
+- When training on multiple GPUs, it is recommended to scale the batch size by the number of GPUs, as in the sample invocation below.
+
+  - e.g. 1 GPU with batch size 128 -> 2 GPUs with batch size 256
+
+
+- arguments
+  - batch_size : Define batch size (default=64)
+  - epochs : Define total epochs (default=100)
+  - image_path : Path for lip image files (default='./data/datasets/')
+  - align_path : Path for align files (default='./data/align/')
+  - dr_rate : Dropout rate (default=0.5)
+  - num_gpus : Number of GPUs (if num_gpus is 0, then use CPU) (default=1)
+  - num_workers : Number of workers for data loading (default=0)
+  - model_path : Path of pretrained model (default=None)
+  
+```
+python main.py
+```
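+
+For example, a hypothetical two-GPU run following the batch-size guideline above:
+
+```
+python main.py --batch_size=256 --epochs=100 --num_gpus=2 --dr_rate=0.5
+```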
+
+---
+
+## Test Environment
+- 72 CPU cores
+- 1 GPU (NVIDIA Tesla V100 SXM2 32 GB)
+- Batch size: 128
+
+- It takes over 24 hours (60 epochs) to get good results.
+
+---
+
+## Inference
+
+- arguments
+  - batch_size : Define batch size (default=64)
+  - image_path : Path for lip image files (default='./data/datasets/')
+  - align_path : Path for align files (default='./data/align/')
+  - num_gpus : Number of GPUs (if num_gpus is 0, then use CPU) (default=1)
+  - num_workers : Number of workers for data loading (default=0)
+  - data_type : 'train' or 'valid' (default='valid')
+  - model_path : Path of pretrained model (default=None)
+    
+```
+python infer.py --model_path=$(model_path)
+```
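+
+For example, pointing at a checkpoint saved by the trainer (the file name below is illustrative; use your own checkpoint path):
+
+```
+python infer.py --model_path=./checkpoint/epoches_60_loss_0.2072 --data_type=valid --num_gpus=1
+```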
+
+
+```
+[Target]
+['lay green with a zero again',
+ 'bin blue with r nine please',
+ 'set blue with e five again',
+ 'bin green by t seven soon',
+ 'lay red at d five now',
+ 'bin green in x eight now',
+ 'bin blue with e one now',
+ 'lay red at j nine now']
+```
+
+```
+[Pred]
+['lay green with s zero again',
+ 'bin blue with r nine please',
+ 'set blue with e five again',
+ 'bin green by t seven soon',
+ 'lay red at c five now',
+ 'bin green in x eight now',
+ 'bin blue with m one now',
+ 'lay red at j nine now']
+```
+
diff --git a/example/gluon/lipnet/asset/mouth_000.png b/example/gluon/lipnet/asset/mouth_000.png
new file mode 100644
index 0000000..b318e56
Binary files /dev/null and b/example/gluon/lipnet/asset/mouth_000.png differ
diff --git a/example/gluon/lipnet/asset/mouth_001.png b/example/gluon/lipnet/asset/mouth_001.png
new file mode 100644
index 0000000..60bd04a
Binary files /dev/null and b/example/gluon/lipnet/asset/mouth_001.png differ
diff --git a/example/gluon/lipnet/asset/mouth_074.png b/example/gluon/lipnet/asset/mouth_074.png
new file mode 100644
index 0000000..e5e0d78
Binary files /dev/null and b/example/gluon/lipnet/asset/mouth_074.png differ
diff --git a/example/gluon/lipnet/asset/network_structure.png b/example/gluon/lipnet/asset/network_structure.png
new file mode 100644
index 0000000..eeec2cb
Binary files /dev/null and b/example/gluon/lipnet/asset/network_structure.png differ
diff --git a/example/gluon/lipnet/asset/s2_bbbf7p_000.png b/example/gluon/lipnet/asset/s2_bbbf7p_000.png
new file mode 100644
index 0000000..6495d2f
Binary files /dev/null and b/example/gluon/lipnet/asset/s2_bbbf7p_000.png differ
diff --git a/example/gluon/lipnet/asset/s2_bbbf7p_001.png b/example/gluon/lipnet/asset/s2_bbbf7p_001.png
new file mode 100644
index 0000000..2a7e269
Binary files /dev/null and b/example/gluon/lipnet/asset/s2_bbbf7p_001.png differ
diff --git a/example/gluon/lipnet/asset/s2_bbbf7p_074.png b/example/gluon/lipnet/asset/s2_bbbf7p_074.png
new file mode 100644
index 0000000..eabd392
Binary files /dev/null and b/example/gluon/lipnet/asset/s2_bbbf7p_074.png differ
diff --git a/example/gluon/lipnet/checkpoint/__init__.py b/example/gluon/lipnet/checkpoint/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/example/gluon/lipnet/checkpoint/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/example/gluon/lipnet/data_loader.py b/example/gluon/lipnet/data_loader.py
new file mode 100644
index 0000000..e3cc24b
--- /dev/null
+++ b/example/gluon/lipnet/data_loader.py
@@ -0,0 +1,94 @@
+"""
+Description : Dataset module for lip images
+"""
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import glob
+from mxnet import nd
+import mxnet.gluon.data.dataset as dataset
+from mxnet.gluon.data.vision.datasets import image
+from utils.align import Align
+
+# pylint: disable=too-many-instance-attributes, too-many-arguments
+class LipsDataset(dataset.Dataset):
+    """
+    Description : DataSet class for lip images
+    """
+    def __init__(self, root, align_root, flag=1,
+                 mode='train', transform=None, seq_len=75):
+        assert mode in ['train', 'valid']
+        self._root = os.path.expanduser(root)
+        self._align_root = align_root
+        self._flag = flag
+        self._transform = transform
+        self._exts = ['.jpg', '.jpeg', '.png']
+        self._seq_len = seq_len
+        self._mode = mode
+        self._list_images(self._root)
+
+    def _list_images(self, root):
+        """
+        Description : generate list for lip images
+        """
+        self.labels = []
+        self.items = []
+
+        valid_unseen_sub_idx = [1, 2, 20, 22]
+        skip_sub_idx = [21]
+
+        if self._mode == 'train':
+            sub_idx = ['s' + str(i) for i in range(1, 35) \
+                             if i not in valid_unseen_sub_idx + skip_sub_idx]
+        elif self._mode == 'valid':
+            sub_idx = ['s' + str(i) for i in valid_unseen_sub_idx]
+
+        folder_path = []
+        for i in sub_idx:
+            folder_path.extend(glob.glob(os.path.join(root, i, "*")))
+
+        for folder in folder_path:
+            filename = glob.glob(os.path.join(folder, "*"))
+            if len(filename) != self._seq_len:
+                continue
+            filename.sort()
+            label = os.path.split(folder)[-1]
+            self.items.append((filename, label))
+
+    def align_generation(self, file_nm, padding=75):
+        """
+        Description : Load the align file and return the padded sentence vector
+        """
+        align = Align(self._align_root + '/' + file_nm + '.align')
+        return nd.array(align.sentence(padding))
+
+    def __getitem__(self, idx):
+        img = list()
+        for image_name in self.items[idx][0]:
+            tmp_img = image.imread(image_name, self._flag)
+            if self._transform is not None:
+                tmp_img = self._transform(tmp_img)
+            img.append(tmp_img)
+        img = nd.stack(*img)
+        img = nd.transpose(img, (1, 0, 2, 3))
+        label = self.align_generation(self.items[idx][1],
+                                      padding=self._seq_len)
+        return img, label
+
+    def __len__(self):
+        return len(self.items)
diff --git a/example/gluon/lipnet/infer.py b/example/gluon/lipnet/infer.py
new file mode 100644
index 0000000..746df9a
--- /dev/null
+++ b/example/gluon/lipnet/infer.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Description : main module to run the lipnet inference code
+"""
+
+
+import argparse
+from trainer import Train
+
+def main():
+    """
+    Description : run lipnet inference code using argument info
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--image_path', type=str, default='./data/datasets/')
+    parser.add_argument('--align_path', type=str, default='./data/align/')
+    parser.add_argument('--num_gpus', type=int, default=1)
+    parser.add_argument('--num_workers', type=int, default=0)
+    parser.add_argument('--data_type', type=str, default='valid')
+    parser.add_argument('--model_path', type=str, default=None)
+    config = parser.parse_args()
+    trainer = Train(config)
+    trainer.build_model(path=config.model_path)
+    trainer.load_dataloader()
+
+    if config.data_type == 'train':
+        data_loader = trainer.train_dataloader
+    elif config.data_type == 'valid':
+        data_loader = trainer.valid_dataloader
+
+    trainer.infer_batch(data_loader)
+
+if __name__ == "__main__":
+    main()
+    
\ No newline at end of file
diff --git a/example/gluon/lipnet/main.py b/example/gluon/lipnet/main.py
new file mode 100644
index 0000000..8e5e756
--- /dev/null
+++ b/example/gluon/lipnet/main.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Description : main module to run the lipnet training code
+"""
+
+
+import argparse
+from trainer import Train
+
+def main():
+    """
+    Description : run lipnet training code using argument info
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--epochs', type=int, default=100)
+    parser.add_argument('--image_path', type=str, default='./data/datasets/')
+    parser.add_argument('--align_path', type=str, default='./data/align/')
+    parser.add_argument('--dr_rate', type=float, default=0.5)
+    parser.add_argument('--num_gpus', type=int, default=1)
+    parser.add_argument('--num_workers', type=int, default=0)
+    parser.add_argument('--model_path', type=str, default=None)
+    config = parser.parse_args()
+    trainer = Train(config)
+    trainer.build_model(dr_rate=config.dr_rate, path=config.model_path)
+    trainer.load_dataloader()
+    trainer.run(epochs=config.epochs)
+
+if __name__ == "__main__":
+    main()
+    
\ No newline at end of file
diff --git a/example/gluon/lipnet/models/__init__.py b/example/gluon/lipnet/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/example/gluon/lipnet/models/network.py b/example/gluon/lipnet/models/network.py
new file mode 100644
index 0000000..b8f005a
--- /dev/null
+++ b/example/gluon/lipnet/models/network.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Description : LipNet module using gluon
+"""
+
+from mxnet.gluon import nn, rnn
+# pylint: disable=too-many-instance-attributes
+class LipNet(nn.HybridBlock):
+    """
+    Description : LipNet network using gluon
+    dr_rate : Dropout rate
+    """
+    def __init__(self, dr_rate, **kwargs):
+        super(LipNet, self).__init__(**kwargs)
+        with self.name_scope():
+            self.conv1 = nn.Conv3D(32, kernel_size=(3, 5, 5), strides=(1, 2, 2), padding=(1, 2, 2))
+            self.bn1 = nn.InstanceNorm(in_channels=32)
+            self.dr1 = nn.Dropout(dr_rate, axes=(1, 2))
+            self.pool1 = nn.MaxPool3D((1, 2, 2), (1, 2, 2))
+            self.conv2 = nn.Conv3D(64, kernel_size=(3, 5, 5), strides=(1, 1, 1), padding=(1, 2, 2))
+            self.bn2 = nn.InstanceNorm(in_channels=64)
+            self.dr2 = nn.Dropout(dr_rate, axes=(1, 2))
+            self.pool2 = nn.MaxPool3D((1, 2, 2), (1, 2, 2))
+            self.conv3 = nn.Conv3D(96, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding=(1, 2, 2))
+            self.bn3 = nn.InstanceNorm(in_channels=96)
+            self.dr3 = nn.Dropout(dr_rate, axes=(1, 2))
+            self.pool3 = nn.MaxPool3D((1, 2, 2), (1, 2, 2))
+            self.gru1 = rnn.GRU(256, bidirectional=True)
+            self.gru2 = rnn.GRU(256, bidirectional=True)
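+            # 28 output classes: 26 letters + space + the CTC blank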
+            self.dense = nn.Dense(27+1, flatten=False)
+
+    # pylint: disable=arguments-differ
+    def hybrid_forward(self, F, x):
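+        # x: (batch, channel, seq_len, height, width); the data loader stacks
+        # 75 mouth frames per sample and puts the channel axis first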
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = F.relu(out)
+        out = self.dr1(out)
+        out = self.pool1(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = F.relu(out)
+        out = self.dr2(out)
+        out = self.pool2(out)
+        out = self.conv3(out)
+        out = self.bn3(out)
+        out = F.relu(out)
+        out = self.dr3(out)
+        out = self.pool3(out)
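+        # collapse spatial features per frame:
+        # (batch, channel, seq, h, w) -> (seq, batch, channel*h*w) for the bi-GRUs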
+        out = F.transpose(out, (2, 0, 1, 3, 4))
+        # pylint: disable=no-member
+        out = out.reshape((0, 0, -1))
+        out = self.gru1(out)
+        out = self.gru2(out)
+        out = self.dense(out)
+        out = F.log_softmax(out, axis=2)
+        out = F.transpose(out, (1, 0, 2))
+        return out
diff --git a/example/gluon/lipnet/requirements.txt b/example/gluon/lipnet/requirements.txt
new file mode 100644
index 0000000..f1fcda3
--- /dev/null
+++ b/example/gluon/lipnet/requirements.txt
@@ -0,0 +1,7 @@
+dlib==19.15.0
+Pillow==4.1.0
+scipy==0.19.0
+scikit-image==0.13.1
+scikit-video==1.1.11
+sk-video==1.1.10
+tqdm
diff --git a/example/gluon/lipnet/tests/test_beamsearch.py b/example/gluon/lipnet/tests/test_beamsearch.py
new file mode 100644
index 0000000..069cbae
--- /dev/null
+++ b/example/gluon/lipnet/tests/test_beamsearch.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""it is the test for the decode using beam search
+Ref:
+https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py
+"""
+
+import unittest
+import numpy as np
+from BeamSearch import ctcBeamSearch
+
+class TestBeamSearch(unittest.TestCase):
+    """Test Beam Search
+    """
+    def test_ctc_beam_search(self):
+        "test decoder"
+        classes = 'ab'
+        mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]])
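+        # rows are time-steps with probs [P('a'), P('b'), P(blank)]; the paths
+        # collapsing to 'a' sum to 0.4*0.6 + 0.6*0.4 + 0.4*0.4 = 0.64, beating
+        # the empty string (0.6*0.6 = 0.36), so the decoder must return 'a'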
+        print('Test beam search')
+        expected = 'a'
+        actual = ctcBeamSearch(mat, classes, None, k=2, beamWidth=3)[0]
+        print('Expected: "' + expected + '"')
+        print('Actual: "' + actual + '"')
+        self.assertEqual(expected, actual)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/example/gluon/lipnet/trainer.py b/example/gluon/lipnet/trainer.py
new file mode 100644
index 0000000..df5c86e
--- /dev/null
+++ b/example/gluon/lipnet/trainer.py
@@ -0,0 +1,232 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Description : Training module for LipNet
+"""
+
+
+import sys
+import mxnet as mx
+from mxnet import gluon, autograd, nd
+from mxnet.gluon.data.vision import transforms
+from tqdm import tqdm, trange
+from data_loader import LipsDataset
+from models.network import LipNet
+from BeamSearch import ctcBeamSearch
+from utils.common import char_conv, int2char
+# set gpu count
+
+
+def setting_ctx(num_gpus):
+    """
+    Description : set gpu module
+    """
+    if num_gpus > 0:
+        ctx = [mx.gpu(i) for i in range(num_gpus)]
+    else:
+        ctx = [mx.cpu()]
+    return ctx
+
+
+ALPHABET = ''
+for i in range(27):
+    ALPHABET += int2char(i)
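+# ALPHABET holds the 27 label symbols ('a'-'z' plus space); the network's
+# 28th output column is the CTC blank, indexed as len(classes) in ctcBeamSearch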
+
+def char_beam_search(out):
+    """
+    Description : apply beam search for prediction result
+    """
+    out_conv = list()
+    for idx in range(out.shape[0]):
+        probs = out[idx]
+        prob = probs.softmax().asnumpy()
+        line_string_proposals = ctcBeamSearch(prob, ALPHABET, None, k=4, beamWidth=25)
+        out_conv.append(line_string_proposals[0])
+    return out_conv
+
+# pylint: disable=too-many-instance-attributes, too-many-locals
+class Train:
+    """
+    Description : Train class for training network
+    """
+    def __init__(self, config):
+        ##setting hyper-parameters
+        self.batch_size = config.batch_size
+        self.image_path = config.image_path
+        self.align_path = config.align_path
+        self.num_gpus = config.num_gpus
+        self.ctx = setting_ctx(self.num_gpus)
+        self.num_workers = config.num_workers
+        self.seq_len = 75
+
+    def build_model(self, dr_rate=0, path=None):
+        """
+        Description : build network
+        """
+        #set network
+        self.net = LipNet(dr_rate)
+        self.net.hybridize()
+        self.net.initialize(ctx=self.ctx)
+
+        if path is not None:
+            self.load_model(path)
+
+        #set optimizer
+        self.loss_fn = gluon.loss.CTCLoss()
+        self.trainer = gluon.Trainer(self.net.collect_params(), \
+                                     optimizer='SGD')
+
+    def save_model(self, epoch, loss):
+        """
+        Description : save parameter of network weight
+        """
+        prefix = 'checkpoint/epoches'
+        file_name = "{prefix}_{epoch}_loss_{l:.4f}".format(prefix=prefix,
+                                                           epoch=str(epoch),
+                                                           l=loss)
+        self.net.save_parameters(file_name)
+
+    def load_model(self, path=''):
+        """
+        Description : load parameter of network weight
+        """
+        self.net.load_parameters(path)
+
+    def load_dataloader(self):
+        """
+        Description : Setup the dataloader
+        """
+
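+        # fixed per-channel normalization constants (presumably the mean/std
+        # of the preprocessed mouth-crop dataset)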
+        input_transform = transforms.Compose([transforms.ToTensor(), \
+                                             transforms.Normalize((0.7136, 0.4906, 0.3283), \
+                                                                  (0.1138, 0.1078, 0.0917))])
+        training_dataset = LipsDataset(self.image_path,
+                                       self.align_path,
+                                       mode='train',
+                                       transform=input_transform,
+                                       seq_len=self.seq_len)
+
+        self.train_dataloader = mx.gluon.data.DataLoader(training_dataset,
+                                                         batch_size=self.batch_size,
+                                                         shuffle=True,
+                                                         num_workers=self.num_workers)
+
+        valid_dataset = LipsDataset(self.image_path,
+                                    self.align_path,
+                                    mode='valid',
+                                    transform=input_transform,
+                                    seq_len=self.seq_len)
+
+        self.valid_dataloader = mx.gluon.data.DataLoader(valid_dataset,
+                                                         batch_size=self.batch_size,
+                                                         shuffle=True,
+                                                         num_workers=self.num_workers)
+
+    def train(self, data, label, batch_size):
+        """
+        Description : training for LipNet
+        """
+        # pylint: disable=no-member
+        sum_losses = 0
+        len_losses = 0
+        with autograd.record():
+            losses = [self.loss_fn(self.net(X), Y) for X, Y in zip(data, label)]
+        for loss in losses:
+            sum_losses += mx.nd.array(loss).sum().asscalar()
+            len_losses += len(loss)
+            loss.backward()
+        self.trainer.step(batch_size)
+        return sum_losses, len_losses
+
+    def infer(self, input_data, input_label):
+        """
+        Description : Print sentence for prediction result
+        """
+        sum_losses = 0
+        len_losses = 0
+        for data, label in zip(input_data, input_label):
+            pred = self.net(data)
+            sum_losses += mx.nd.array(self.loss_fn(pred, label)).sum().asscalar()
+            len_losses += len(data)
+            pred_convert = char_beam_search(pred)
+            label_convert = char_conv(label.asnumpy())
+            for target, pred in zip(label_convert, pred_convert):
+                print("target:{t}  pred:{p}".format(t=target, p=pred))
+        return sum_losses, len_losses
+
+    def train_batch(self, dataloader):
+        """
+        Description : training for LipNet
+        """
+        sum_losses = 0
+        len_losses = 0
+        for input_data, input_label in tqdm(dataloader):
+            data = gluon.utils.split_and_load(input_data, self.ctx, even_split=False)
+            label = gluon.utils.split_and_load(input_label, self.ctx, even_split=False)
+            batch_size = input_data.shape[0]
+            batch_sum, batch_len = self.train(data, label, batch_size)
+            sum_losses += batch_sum
+            len_losses += batch_len
+
+        return sum_losses, len_losses
+
+    def infer_batch(self, dataloader):
+        """
+        Description : inference for LipNet
+        """
+        sum_losses = 0
+        len_losses = 0
+        for input_data, input_label in dataloader:
+            data = gluon.utils.split_and_load(input_data, self.ctx, even_split=False)
+            label = gluon.utils.split_and_load(input_label, self.ctx, even_split=False)
+            batch_sum, batch_len = self.infer(data, label)
+            sum_losses += batch_sum
+            len_losses += batch_len
+
+        return sum_losses, len_losses
+
+    def run(self, epochs):
+        """
+        Description : Run training for LipNet
+        """
+        best_loss = sys.maxsize
+        for epoch in trange(epochs):
+            ## train
+            sum_losses, len_losses = self.train_batch(self.train_dataloader)
+            current_loss = sum_losses / len_losses
+            print("[Train] epoch:{e} loss:{l:.4f}".format(e=epoch,
+                                                          l=current_loss))
+
+            ## validate
+            sum_val_losses, len_val_losses = self.infer_batch(self.valid_dataloader)
+            current_val_loss = sum_val_losses / len_val_losses
+            print("[Valid] epoch:{e} loss:{l:.4f}".format(e=epoch,
+                                                          l=current_val_loss))
+
+            if best_loss > current_val_loss:
+                self.save_model(epoch, current_val_loss)
+                best_loss = current_val_loss
diff --git a/example/gluon/lipnet/utils/__init__.py b/example/gluon/lipnet/utils/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/example/gluon/lipnet/utils/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/example/gluon/lipnet/utils/align.py b/example/gluon/lipnet/utils/align.py
new file mode 100644
index 0000000..48d0716
--- /dev/null
+++ b/example/gluon/lipnet/utils/align.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Module: align
+This is used when the data is generated by LipsDataset
+"""
+
+import numpy as np
+from .common import word_to_vector
+
+
+class Align(object):
+    """
+    Preprocess for Align
+    """
+    skip_list = ['sil', 'sp']
+
+    def __init__(self, align_path):
+        self.build(align_path)
+
+    def build(self, align_path):
+        """
+        Build the align array
+        """
+        file = open(align_path, 'r')
+        lines = file.readlines()
+        file.close()
+        # words: list([op, ed, word])
+        words = []
+        for line in lines:
+            _op, _ed, word = line.strip().split(' ')
+            if word not in Align.skip_list:
+                words.append((int(_op), int(_ed), word))
+        self.words = words
+        self.n_words = len(words)
+        self.sentence_str = " ".join([w[2] for w in self.words])
+        self.sentence_length = len(self.sentence_str)
+
+    def sentence(self, padding=75):
+        """
+        Get sentence
+        """
+        vec = word_to_vector(self.sentence_str)
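+        # pad with -1 so variable-length sentences can be batched;
+        # gluon's CTCLoss treats -1 as the label padding value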
+        vec += [-1] * (padding - self.sentence_length)
+        return np.array(vec, dtype=np.int32)
+
+    def word(self, _id, padding=75):
+        """
+        Get words
+        """
+        word = self.words[_id][2]
+        vec = word_to_vector(word)
+        vec += [-1] * (padding - len(vec))
+        return np.array(vec, dtype=np.int32)
+
+    def word_length(self, _id):
+        """
+        Get the length of words
+        """
+        return len(self.words[_id][2])
+
+    def word_frame_pos(self, _id):
+        """
+        Get the position of words
+        """
+        left = int(self.words[_id][0]/1000)
+        right = max(left+1, int(self.words[_id][1]/1000))
+        return (left, right)
diff --git a/example/gluon/lipnet/utils/common.py b/example/gluon/lipnet/utils/common.py
new file mode 100644
index 0000000..ec96b68
--- /dev/null
+++ b/example/gluon/lipnet/utils/common.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Module: This module contains common conversion functions
+
+"""
+
+
+def char2int(char):
+    """
+    Convert character to integer.
+    """
+    if char >= 'a' and char <= 'z':
+        return ord(char) - ord('a')
+    elif char == ' ':
+        return 26
+    return None
+
+
+def int2char(num):
+    """
+    Convert integer to character.
+    """
+    if num >= 0 and num < 26:
+        return chr(num + ord('a'))
+    elif num == 26:
+        return ' '
+    return None
+
+
+def word_to_vector(word):
+    """
+    Convert character vectors to integer vectors.
+    """
+    vector = []
+    for char in list(word):
+        vector.append(char2int(char))
+    return vector
+
+
+def vector_to_word(vector):
+    """
+    Convert integer vectors to character vectors.
+    """
+    word = ""
+    for vec in vector:
+        word = word + int2char(vec)
+    return word
+
+
+def char_conv(out):
+    """
+    Convert integer vectors to character vectors for batch.
+    """
+    out_conv = list()
+    for i in range(out.shape[0]):
+        tmp_str = ''
+        for j in range(out.shape[1]):
+            if int(out[i][j]) >= 0:
+                tmp_char = int2char(int(out[i][j]))
+                if int(out[i][j]) == 27:
+                    tmp_char = ''
+                tmp_str = tmp_str + tmp_char
+        out_conv.append(tmp_str)
+    return out_conv
diff --git a/example/gluon/lipnet/utils/download_data.py b/example/gluon/lipnet/utils/download_data.py
new file mode 100644
index 0000000..3051eb2
--- /dev/null
+++ b/example/gluon/lipnet/utils/download_data.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Module: download_data
+This module provides utilities for downloading the datasets for training LipNet
+"""
+
+import os
+from os.path import exists
+from multi import multi_p_run, put_worker
+
+
+def download_mp4(from_idx, to_idx, _params):
+    """
+    download mp4s
+    """
+    succ = set()
+    fail = set()
+    for idx in range(from_idx, to_idx):
+        name = 's' + str(idx)
+        save_folder = '{src_path}/{nm}'.format(src_path=_params['src_path'], nm=name)
+        if idx == 0 or os.path.isdir(save_folder):
+            continue
+        script = "http://spandh.dcs.shef.ac.uk/gridcorpus/{nm}/video/{nm}.mpg_vcd.zip".format( \
+                    nm=name)
+        down_sc = 'cd {src_path} && curl {script} --output {nm}.mpg_vcd.zip && \
+                    unzip {nm}.mpg_vcd.zip'.format(script=script,
+                                                   nm=name,
+                                                   src_path=_params['src_path'])
+        try:
+            print(down_sc)
+            os.system(down_sc)
+            succ.add(idx)
+        except OSError as error:
+            print(error)
+            fail.add(idx)
+    return (succ, fail)
+
+
+def download_align(from_idx, to_idx, _params):
+    """
+    download aligns
+    """
+    succ = set()
+    fail = set()
+    for idx in range(from_idx, to_idx):
+        name = 's' + str(idx)
+        if idx == 0:
+            continue
+        script = "http://spandh.dcs.shef.ac.uk/gridcorpus/{nm}/align/{nm}.tar".format(nm=name)
+        down_sc = 'cd {align_path} && wget {script} && \
+                    tar -xvf {nm}.tar'.format(script=script,
+                                              nm=name,
+                                              align_path=_params['align_path'])
+        try:
+            print(down_sc)
+            os.system(down_sc)
+            succ.add(idx)
+        except OSError as error:
+            print(error)
+            fail.add(idx)
+    return (succ, fail)
+
+
+if __name__ == '__main__':
+    import argparse
+    PARSER = argparse.ArgumentParser()
+    PARSER.add_argument('--src_path', type=str, default='../data/mp4s')
+    PARSER.add_argument('--align_path', type=str, default='../data')
+    PARSER.add_argument('--n_process', type=int, default=1)
+    CONFIG = PARSER.parse_args()
+    PARAMS = {'src_path': CONFIG.src_path, 'align_path': CONFIG.align_path}
+    N_PROCESS = CONFIG.n_process
+
+    if exists('./shape_predictor_68_face_landmarks.dat') is False:
+        os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 && \
+                  bzip2 -d shape_predictor_68_face_landmarks.dat.bz2')
+
+    os.makedirs('{src_path}'.format(src_path=PARAMS['src_path']), exist_ok=True)
+    os.makedirs('{align_path}'.format(align_path=PARAMS['align_path']), exist_ok=True)
+
+    if N_PROCESS == 1:
+        RES = download_mp4(0, 35, PARAMS)
+        RES = download_align(0, 35, PARAMS)
+    else:
+        # download movie files
+        RES = multi_p_run(tot_num=35, _func=put_worker, worker=download_mp4, \
+                          params=PARAMS, n_process=N_PROCESS)
+
+        # download align files
+        RES = multi_p_run(tot_num=35, _func=put_worker, worker=download_align, \
+                          params=PARAMS, n_process=N_PROCESS)
+
+    os.system('rm -f {src_path}/*.zip && rm -f {src_path}/*/Thumbs.db'.format( \
+              src_path=PARAMS['src_path']))
+    os.system('rm -f {align_path}/*.tar && rm -f {align_path}/Thumbs.db'.format( \
+              align_path=PARAMS['align_path']))
diff --git a/example/gluon/lipnet/utils/multi.py b/example/gluon/lipnet/utils/multi.py
new file mode 100644
index 0000000..ce545b5
--- /dev/null
+++ b/example/gluon/lipnet/utils/multi.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Module: preprocess with multi-process
+"""
+
+
+def multi_p_run(tot_num, _func, worker, params, n_process):
+    """
+    Run _func with multi-process using params.
+    """
+    from multiprocessing import Process, Queue
+    out_q = Queue()
+    procs = []
+
+    split_num = split_seq(list(range(0, tot_num)), n_process)
+
+    print(tot_num, ">>", split_num)
+
+    split_len = len(split_num)
+    if n_process > split_len:
+        n_process = split_len
+
+    for i in range(n_process):
+        _p = Process(target=_func,
+                     args=(worker, split_num[i][0], split_num[i][1],
+                           params, out_q))
+        _p.daemon = True
+        procs.append(_p)
+        _p.start()
+
+    try:
+        result = []
+        for i in range(n_process):
+            result.append(out_q.get())
+        for i in procs:
+            i.join()
+    except KeyboardInterrupt:
+        print('Killing all the children in the pool.')
+        for i in procs:
+            i.terminate()
+            i.join()
+        return -1
+
+    while not out_q.empty():
+        print(out_q.get(block=False))
+
+    return result
+
+
+def split_seq(sam_num, n_tile):
+    """
+    Split the index list sam_num into at most n_tile contiguous [start, end) ranges,
+    e.g. split_seq(list(range(10)), 3) -> [[0, 4], [4, 8], [8, 10]]
+    """
+    import math
+    print(sam_num)
+    print(n_tile)
+    start_num = sam_num[0::int(math.ceil(len(sam_num) / (n_tile)))]
+    end_num = start_num[1::]
+    end_num.append(len(sam_num))
+    return [[i, j] for i, j in zip(start_num, end_num)]
+
+
+def put_worker(func, from_idx, to_idx, params, out_q):
+    """
+    put worker
+    """
+    succ, fail = func(from_idx, to_idx, params)
+    return out_q.put({'succ': succ, 'fail': fail})
+
+
+def test_worker(from_idx, to_idx, params):
+    """
+    the worker to test multi-process
+    """
+    params = params
+    succ = set()
+    fail = set()
+    for idx in range(from_idx, to_idx):
+        try:
+            succ.add(idx)
+        except ValueError:
+            fail.add(idx)
+    return (succ, fail)
+
+
+if __name__ == '__main__':
+    RES = multi_p_run(35, put_worker, test_worker, params={}, n_process=5)
+    print(RES)
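+    # Expected output: five dicts such as {'succ': {0, 1, 2, 3, 4, 5, 6},
+    # 'fail': set()}, whose 'succ' sets together cover {0, ..., 34}.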
diff --git a/example/gluon/lipnet/utils/preprocess_data.py b/example/gluon/lipnet/utils/preprocess_data.py
new file mode 100644
index 0000000..a13fad8
--- /dev/null
+++ b/example/gluon/lipnet/utils/preprocess_data.py
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Module: preprocess_data
+Reference: https://github.com/rizkiarm/LipNet
+"""
+
+# pylint: disable=too-many-locals, no-self-use, c-extension-no-member
+
+import os
+import fnmatch
+import errno
+import numpy as np
+from scipy import ndimage
+from scipy.misc import imresize
+from skimage import io
+import skvideo.io
+import dlib
+
+def mkdir_p(path):
+    """
+    Make a directory
+    """
+    try:
+        os.makedirs(path)
+    except OSError as exc:  # Python >2.5
+        if exc.errno == errno.EEXIST and os.path.isdir(path):
+            pass
+        else:
+            raise
+
+def find_files(directory, pattern):
+    """
+    Find files
+    """
+    for root, _, files in os.walk(directory):
+        for basename in files:
+            if fnmatch.fnmatch(basename, pattern):
+                filename = os.path.join(root, basename)
+                yield filename
+
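+# Example (illustrative path): find_files('../data/mp4s', '*.mpg') lazily
+# yields the path of every .mpg file under ../data/mp4s.
+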
+class Video(object):
+    """
+    Preprocess for Video
+    """
+    def __init__(self, vtype='mouth', face_predictor_path=None):
+        if vtype == 'face' and face_predictor_path is None:
+            raise AttributeError('Face video needs to be accompanied by a face predictor')
+        self.face_predictor_path = face_predictor_path
+        self.vtype = vtype
+        self.face = None
+        self.mouth = None
+        self.data = None
+        self.length = None
+
+    def from_frames(self, path):
+        """
+        Read from frames
+        """
+        frames_path = sorted([os.path.join(path, x) for x in os.listdir(path)])
+        frames = [ndimage.imread(frame_path) for frame_path in frames_path]
+        self.handle_type(frames)
+        return self
+
+    def from_video(self, path):
+        """
+        Read from videos
+        """
+        frames = self.get_video_frames(path)
+        self.handle_type(frames)
+        return self
+
+    def from_array(self, frames):
+        """
+        Read from array
+        """
+        self.handle_type(frames)
+        return self
+
+    def handle_type(self, frames):
+        """
+        Config video types
+        """
+        if self.vtype == 'mouth':
+            self.process_frames_mouth(frames)
+        elif self.vtype == 'face':
+            self.process_frames_face(frames)
+        else:
+            raise Exception('Video type not found')
+
+    def process_frames_face(self, frames):
+        """
+        Preprocess from frames using face detector
+        """
+        detector = dlib.get_frontal_face_detector()
+        predictor = dlib.shape_predictor(self.face_predictor_path)
+        mouth_frames = self.get_frames_mouth(detector, predictor, frames)
+        self.face = np.array(frames)
+        self.mouth = np.array(mouth_frames)
+        if mouth_frames[0] is not None:
+            self.set_data(mouth_frames)
+
+    def process_frames_mouth(self, frames):
+        """
+        Preprocess from frames using mouth detector
+        """
+        self.face = np.array(frames)
+        self.mouth = np.array(frames)
+        self.set_data(frames)
+
+    def get_frames_mouth(self, detector, predictor, frames):
+        """
+        Get frames using mouth crop
+        """
+        mouth_width = 100
+        mouth_height = 50
+        horizontal_pad = 0.19
+        normalize_ratio = None
+        mouth_frames = []
+        for frame in frames:
+            dets = detector(frame, 1)
+            shape = None
+            for det in dets:
+                shape = predictor(frame, det)
+            if shape is None:  # the detector found no face; give up on this video
+                return [None]
+            mouth_points = []
+            for i, part in enumerate(shape.parts()):
+                if i < 48:  # landmarks 48-67 cover the mouth region
+                    continue
+                mouth_points.append((part.x, part.y))
+            np_mouth_points = np.array(mouth_points)
+
+            mouth_centroid = np.mean(np_mouth_points[:, -2:], axis=0)  # (x, y) centroid
+
+            if normalize_ratio is None:
+                # pad the mouth's horizontal extent, then scale it to mouth_width
+                mouth_left = np.min(np_mouth_points[:, :-1]) * (1.0 - horizontal_pad)
+                mouth_right = np.max(np_mouth_points[:, :-1]) * (1.0 + horizontal_pad)
+
+                normalize_ratio = mouth_width / float(mouth_right - mouth_left)
+
+            new_img_shape = (int(frame.shape[0] * normalize_ratio),
+                             int(frame.shape[1] * normalize_ratio))
+            resized_img = imresize(frame, new_img_shape)
+
+            mouth_centroid_norm = mouth_centroid * normalize_ratio
+
+            mouth_l = int(mouth_centroid_norm[0] - mouth_width / 2)
+            mouth_r = int(mouth_centroid_norm[0] + mouth_width / 2)
+            mouth_t = int(mouth_centroid_norm[1] - mouth_height / 2)
+            mouth_b = int(mouth_centroid_norm[1] + mouth_height / 2)
+
+            mouth_crop_image = resized_img[mouth_t:mouth_b, mouth_l:mouth_r]
+
+            mouth_frames.append(mouth_crop_image)
+        return mouth_frames
+
+    def get_video_frames(self, path):
+        """
+        Get video frames
+        """
+        videogen = skvideo.io.vreader(path)
+        frames = np.array([frame for frame in videogen])
+        return frames
+
+    def set_data(self, frames):
+        """
+        Prepare the input of model
+        """
+        data_frames = []
+        for frame in frames:
+            #frame H x W x C
+            frame = frame.swapaxes(0, 1) # swap width and height to form format W x H x C
+            if len(frame.shape) < 3:
+                frame = np.array([frame]).swapaxes(0, 2).swapaxes(0, 1) # Add grayscale channel
+            data_frames.append(frame)
+        frames_n = len(data_frames)
+        data_frames = np.array(data_frames) # T x W x H x C
+        data_frames = np.rollaxis(data_frames, 3) # C x T x W x H
+        data_frames = data_frames.swapaxes(2, 3) # C x T x H x W  = NCDHW
+
+        self.data = data_frames
+        self.length = frames_n
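+        # Shape walkthrough with hypothetical numbers: 75 RGB frames of
+        # H x W = 50 x 100 enter as a list of (50, 100, 3) arrays and leave
+        # as data_frames of shape C x T x H x W = (3, 75, 50, 100).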
+
+def preprocess(from_idx, to_idx, _params):
+    """
+    Preprocess: Convert a video into the mouth images
+    """
+    source_exts = '*.mpg'
+    src_path = _params['src_path']
+    tgt_path = _params['tgt_path']
+    face_predictor_path = './shape_predictor_68_face_landmarks.dat'
+
+    succ = set()
+    fail = set()
+    for idx in range(from_idx, to_idx):
+        s_id = 's' + str(idx) + '/'
+        source_path = src_path + '/' + s_id
+        target_path = tgt_path + '/' + s_id
+        fail_cnt = 0
+        for filepath in find_files(source_path, source_exts):
+            print("Processing: {}".format(filepath))
+            filepath_wo_ext = os.path.splitext(filepath)[0].split('/')[-2:]
+            target_dir = os.path.join(tgt_path, '/'.join(filepath_wo_ext))
+
+            if os.path.exists(target_dir):
+                continue
+
+            try:
+                video = Video(vtype='face',
+                              face_predictor_path=face_predictor_path).from_video(filepath)
+                mkdir_p(target_dir)
+                if video.mouth[0] is None:
+                    continue
+                for i, frame in enumerate(video.mouth):
+                    io.imsave(os.path.join(target_dir, "mouth_{0:03d}.png".format(i)), frame)
+            except ValueError as error:
+                print(error)
+                fail_cnt += 1
+        if fail_cnt == 0:
+            succ.add(idx)
+        else:
+            fail.add(idx)
+    return (succ, fail)
+
+if __name__ == '__main__':
+    import argparse
+    from multi import multi_p_run, put_worker
+    PARSER = argparse.ArgumentParser()
+    PARSER.add_argument('--src_path', type=str, default='../data/mp4s')
+    PARSER.add_argument('--tgt_path', type=str, default='../data/datasets')
+    PARSER.add_argument('--n_process', type=int, default=1)
+    CONFIG = PARSER.parse_args()
+    N_PROCESS = CONFIG.n_process
+    PARAMS = {'src_path': CONFIG.src_path,
+              'tgt_path': CONFIG.tgt_path}
+
+    os.makedirs(PARAMS['tgt_path'], exist_ok=True)
+
+    if N_PROCESS == 1:
+        RES = preprocess(0, 35, PARAMS)
+    else:
+        RES = multi_p_run(35, put_worker, preprocess, PARAMS, N_PROCESS)
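+
+    # Minimal usage sketch with the default paths defined above:
+    #   python preprocess_data.py --src_path ../data/mp4s \
+    #       --tgt_path ../data/datasets --n_process 4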
diff --git a/example/gluon/lipnet/utils/run_preprocess.ipynb b/example/gluon/lipnet/utils/run_preprocess.ipynb
new file mode 100644
index 0000000..7a25e9b
--- /dev/null
+++ b/example/gluon/lipnet/utils/run_preprocess.ipynb
@@ -0,0 +1,194 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "from download_data import multi_p_run, put_worker, _worker, download_mp4, download_align"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## TEST"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]\n",
+      "5\n",
+      "35 >> [[0, 7], [7, 14], [14, 21], [21, 28], [28, 35]]\n",
+      "[{'succ': {0, 1, 2, 3, 4, 5, 6}, 'fail': set()}, {'succ': {7, 8, 9, 10, 11, 12, 13}, 'fail': set()}, {'succ': {14, 15, 16, 17, 18, 19, 20}, 'fail': set()}, {'succ': {21, 22, 23, 24, 25, 26, 27}, 'fail': set()}, {'succ': {32, 33, 34, 28, 29, 30, 31}, 'fail': set()}]\n"
+     ]
+    }
+   ],
+   "source": [
+    "res = multi_p_run(35, put_worker, _worker, 5)\n",
+    "print (res)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Download Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## down\n",
+    "import os\n",
+    "os.makedirs('./datasets', exist_ok=True)\n",
+    "#os.system('rm -rf ./datasets/*')\n",
+    "\n",
+    "res = multi_p_run(35, put_worker, download_align, 9)\n",
+    "print (res)\n",
+    "\n",
+    "os.system('rm -f datasets/*.tar && rm -f datasets/align/Thumbs.db')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "res = multi_p_run(35, put_worker, download_mp4, 9)\n",
+    "print (res)\n",
+    "\n",
+    "os.system('rm -f datasets/*.zip && rm -f datasets/*/Thumbs.db')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## download single 22 th dir\n",
+    "#download_data.py(22, 22)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Preprocess Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from preprocess_data import preprocess, find_files, Video"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import os\n",
+    "os.makedirs('./TARGET', exist_ok=True)\n",
+    "os.system('rm -rf ./TARGET/*')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]\n",
+      "9\n",
+      "35 >> [[0, 4], [4, 8], [8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 32], [32, 35]]\n",
+      "Processing: datasets/s1/prwq3s.mpg\n",
+      "Processing: datasets/s4/lrix7n.mpg\n",
+      "Processing: datasets/s8/pgbyza.mpg\n",
+      "Processing: datasets/s12/brik7n.mpg\n",
+      "Processing: datasets/s16/sgit7p.mpg\n",
+      "Processing: datasets/s20/lrbp8a.mpg\n",
+      "Processing: datasets/s24/sbik8a.mpg\n",
+      "Processing: datasets/s28/srwf8a.mpg\n",
+      "Processing: datasets/s32/pbbm1n.mpg\n",
+      "Processing: datasets/s12/sbbaza.mpg\n",
+      "Processing: datasets/s28/lbit7n.mpg\n",
+      "Processing: datasets/s32/pbwm7p.mpg\n",
+      "Processing: datasets/s8/bril2s.mpg\n",
+      "Processing: datasets/s20/bway7n.mpg\n",
+      "Processing: datasets/s1/pbib8p.mpg\n",
+      "Processing: datasets/s16/lwaj7n.mpg\n",
+      "Processing: datasets/s24/bwwl6a.mpg\n",
+      "Processing: datasets/s4/bbwf7n.mpg\n"
+     ]
+    }
+   ],
+   "source": [
+    "res = multi_p_run(35, put_worker, preprocess, 9)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/example/gluon/lipnet/utils/run_preprocess_single_process.ipynb b/example/gluon/lipnet/utils/run_preprocess_single_process.ipynb
new file mode 100644
index 0000000..4311323
--- /dev/null
+++ b/example/gluon/lipnet/utils/run_preprocess_single_process.ipynb
@@ -0,0 +1,360 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "from download_data import multi_p_run, put_worker, test_worker, download_mp4, download_align"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tot_movies=35"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## TEST"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]\n",
+      "5\n",
+      "35 >> [[0, 7], [7, 14], [14, 21], [21, 28], [28, 35]]\n",
+      "[{'succ': {0, 1, 2, 3, 4, 5, 6}, 'fail': set()}, {'succ': {7, 8, 9, 10, 11, 12, 13}, 'fail': set()}, {'succ': {14, 15, 16, 17, 18, 19, 20}, 'fail': set()}, {'succ': {21, 22, 23, 24, 25, 26, 27}, 'fail': set()}, {'succ': {32, 33, 34, 28, 29, 30, 31}, 'fail': set()}]\n"
+     ]
+    }
+   ],
+   "source": [
+    "res = multi_p_run(tot_movies, put_worker, test_worker, params={}, n_process=5)\n",
+    "print (res)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Download Data"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Aligns"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s0/align/s0.tar && tar -xvf s0.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s1/align/s1.tar && tar -xvf s1.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s2/align/s2.tar && tar -xvf s2.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s3/align/s3.tar && tar -xvf s3.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s4/align/s4.tar && tar -xvf s4.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s5/align/s5.tar && tar -xvf s5.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s6/align/s6.tar && tar -xvf s6.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s7/align/s7.tar && tar -xvf s7.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s8/align/s8.tar && tar -xvf s8.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s9/align/s9.tar && tar -xvf s9.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s10/align/s10.tar && tar -xvf s10.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s11/align/s11.tar && tar -xvf s11.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s12/align/s12.tar && tar -xvf s12.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s13/align/s13.tar && tar -xvf s13.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s14/align/s14.tar && tar -xvf s14.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s15/align/s15.tar && tar -xvf s15.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s16/align/s16.tar && tar -xvf s16.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s17/align/s17.tar && tar -xvf s17.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s18/align/s18.tar && tar -xvf s18.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s19/align/s19.tar && tar -xvf s19.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s20/align/s20.tar && tar -xvf s20.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s21/align/s21.tar && tar -xvf s21.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s22/align/s22.tar && tar -xvf s22.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s23/align/s23.tar && tar -xvf s23.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s24/align/s24.tar && tar -xvf s24.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s25/align/s25.tar && tar -xvf s25.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s26/align/s26.tar && tar -xvf s26.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s27/align/s27.tar && tar -xvf s27.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s28/align/s28.tar && tar -xvf s28.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s29/align/s29.tar && tar -xvf s29.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s30/align/s30.tar && tar -xvf s30.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s31/align/s31.tar && tar -xvf s31.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s32/align/s32.tar && tar -xvf s32.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s33/align/s33.tar && tar -xvf s33.tar\n",
+      "cd ../data/align && wget http://spandh.dcs.shef.ac.uk/gridcorpus/s34/align/s34.tar && tar -xvf s34.tar\n"
+     ]
+    }
+   ],
+   "source": [
+    "align_path = '../data/align'\n",
+    "os.makedirs(align_path, exist_ok=True)\n",
+    "\n",
+    "res = download_align(0, tot_movies, {'align_path':align_path})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}, set())\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "0"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "print (res)\n",
+    "os.system('rm -f {align_path}/*.tar && rm -f {align_path}/Thumbs.db'.format(align_path=align_path))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Moives(MP4s)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s0/video/s0.mpg_vcd.zip --output s0.mpg_vcd.zip && unzip s0.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s1/video/s1.mpg_vcd.zip --output s1.mpg_vcd.zip && unzip s1.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s2/video/s2.mpg_vcd.zip --output s2.mpg_vcd.zip && unzip s2.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s3/video/s3.mpg_vcd.zip --output s3.mpg_vcd.zip && unzip s3.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s4/video/s4.mpg_vcd.zip --output s4.mpg_vcd.zip && unzip s4.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s5/video/s5.mpg_vcd.zip --output s5.mpg_vcd.zip && unzip s5.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s6/video/s6.mpg_vcd.zip --output s6.mpg_vcd.zip && unzip s6.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s7/video/s7.mpg_vcd.zip --output s7.mpg_vcd.zip && unzip s7.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s8/video/s8.mpg_vcd.zip --output s8.mpg_vcd.zip && unzip s8.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s9/video/s9.mpg_vcd.zip --output s9.mpg_vcd.zip && unzip s9.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s10/video/s10.mpg_vcd.zip --output s10.mpg_vcd.zip && unzip s10.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s11/video/s11.mpg_vcd.zip --output s11.mpg_vcd.zip && unzip s11.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s12/video/s12.mpg_vcd.zip --output s12.mpg_vcd.zip && unzip s12.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s13/video/s13.mpg_vcd.zip --output s13.mpg_vcd.zip && unzip s13.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s14/video/s14.mpg_vcd.zip --output s14.mpg_vcd.zip && unzip s14.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s15/video/s15.mpg_vcd.zip --output s15.mpg_vcd.zip && unzip s15.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s16/video/s16.mpg_vcd.zip --output s16.mpg_vcd.zip && unzip s16.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s17/video/s17.mpg_vcd.zip --output s17.mpg_vcd.zip && unzip s17.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s18/video/s18.mpg_vcd.zip --output s18.mpg_vcd.zip && unzip s18.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s19/video/s19.mpg_vcd.zip --output s19.mpg_vcd.zip && unzip s19.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s20/video/s20.mpg_vcd.zip --output s20.mpg_vcd.zip && unzip s20.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s21/video/s21.mpg_vcd.zip --output s21.mpg_vcd.zip && unzip s21.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s22/video/s22.mpg_vcd.zip --output s22.mpg_vcd.zip && unzip s22.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s23/video/s23.mpg_vcd.zip --output s23.mpg_vcd.zip && unzip s23.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s24/video/s24.mpg_vcd.zip --output s24.mpg_vcd.zip && unzip s24.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s25/video/s25.mpg_vcd.zip --output s25.mpg_vcd.zip && unzip s25.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s26/video/s26.mpg_vcd.zip --output s26.mpg_vcd.zip && unzip s26.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s27/video/s27.mpg_vcd.zip --output s27.mpg_vcd.zip && unzip s27.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s28/video/s28.mpg_vcd.zip --output s28.mpg_vcd.zip && unzip s28.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s29/video/s29.mpg_vcd.zip --output s29.mpg_vcd.zip && unzip s29.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s30/video/s30.mpg_vcd.zip --output s30.mpg_vcd.zip && unzip s30.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s31/video/s31.mpg_vcd.zip --output s31.mpg_vcd.zip && unzip s31.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s32/video/s32.mpg_vcd.zip --output s32.mpg_vcd.zip && unzip s32.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s33/video/s33.mpg_vcd.zip --output s33.mpg_vcd.zip && unzip s33.mpg_vcd.zip\n",
+      "cd ../data/mp4s && curl http://spandh.dcs.shef.ac.uk/gridcorpus/s34/video/s34.mpg_vcd.zip --output s34.mpg_vcd.zip && unzip s34.mpg_vcd.zip\n"
+     ]
+    }
+   ],
+   "source": [
+    "src_path = '../data/mp4s'\n",
+    "res = download_mp4(0, tot_movies, {'src_path':src_path})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}, set())\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "0"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "print (res)\n",
+    "os.system('rm -f {src_path}/*.zip && rm -f {src_path}/*/Thumbs.db'.format(src_path=src_path))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Preprocess Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from preprocess_data import preprocess, find_files, Video"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tgt_path = '../data/datasets'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "os.makedirs('{tgt_path}'.format(tgt_path=tgt_path), exist_ok=True)\n",
+    "os.system('rm -rf {tgt_path}'.format(tgt_path=tgt_path))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "res = preprocess(0, tot_movies, {'src_path':src_path, 'tgt_path':tgt_path})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}, set())\n"
+     ]
+    }
+   ],
+   "source": [
+    "print (res)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python [default]",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}