You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/08/21 18:09:42 UTC

[incubator-mxnet] branch master updated: [MXNET-836] RNN Example for Scala (#11753)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 38f80af  [MXNET-836] RNN Example for Scala (#11753)
38f80af is described below

commit 38f80af6a4ac1ec1760772d5a407c39c876c16e6
Author: Lanking <la...@live.com>
AuthorDate: Tue Aug 21 11:09:23 2018 -0700

    [MXNET-836] RNN Example for Scala (#11753)
    
    * initial fix for RNN
    
    * add CI test
    
    * add encoding format
    
    * scala style fix
    
    * update readme
    
    * test char RNN works
    
    * ignore the test due to memory leaks
---
 .../org/apache/mxnetexamples/rnn/BucketIo.scala    |  19 +-
 .../scala/org/apache/mxnetexamples/rnn/Lstm.scala  |  97 ++++-----
 .../apache/mxnetexamples/rnn/LstmBucketing.scala   | 110 +++++-----
 .../scala/org/apache/mxnetexamples/rnn/README.md   |  48 +++++
 .../org/apache/mxnetexamples/rnn/TestCharRnn.scala |  96 +++++----
 .../apache/mxnetexamples/rnn/TrainCharRnn.scala    | 237 ++++++++++-----------
 .../scala/org/apache/mxnetexamples/rnn/Utils.scala |   3 -
 .../apache/mxnetexamples/rnn/ExampleRNNSuite.scala |  75 +++++++
 8 files changed, 399 insertions(+), 286 deletions(-)

diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index d4b1707..6d414bb 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -34,7 +34,7 @@ object BucketIo {
   type ReadContent = String => String
 
   def defaultReadContent(path: String): String = {
-    Source.fromFile(path).mkString.replaceAll("\\. |\n", " <eos> ")
+    Source.fromFile(path, "UTF-8").mkString.replaceAll("\\. |\n", " <eos> ")
   }
 
   def defaultBuildVocab(path: String): Map[String, Int] = {
@@ -56,7 +56,7 @@ object BucketIo {
       val tmp = sentence.split(" ").filter(_.length() > 0)
       for (w <- tmp) yield theVocab(w)
     }
-    words.toArray
+    words
   }
 
   def defaultGenBuckets(sentences: Array[String], batchSize: Int,
@@ -162,8 +162,6 @@ object BucketIo {
       labelBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket)))
     }
 
-    private val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
-
     private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
       tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
     }
@@ -208,12 +206,13 @@ object BucketIo {
         tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
       }
       val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape)
-      new DataBatch(IndexedSeq(dataBuf) ++ initStateArrays,
-                    IndexedSeq(labelBuf),
-                    getIndex(),
-                    getPad(),
-                    this.buckets(bucketIdx).asInstanceOf[AnyRef],
-                    batchProvideData, batchProvideLabel)
+      val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
+      new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
+        IndexedSeq(labelBuf.copy()),
+        getIndex(),
+        getPad(),
+        this.buckets(bucketIdx).asInstanceOf[AnyRef],
+        batchProvideData, batchProvideLabel)
     }
 
     /**
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
index bf29a47..872ef78 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
@@ -18,13 +18,10 @@
 
 package org.apache.mxnetexamples.rnn
 
-import org.apache.mxnet.Symbol
+import org.apache.mxnet.{Shape, Symbol}
 
 import scala.collection.mutable.ArrayBuffer
 
-/**
- * @author Depeng Liang
- */
 object Lstm {
 
   final case class LSTMState(c: Symbol, h: Symbol)
@@ -35,27 +32,22 @@ object Lstm {
   def lstm(numHidden: Int, inData: Symbol, prevState: LSTMState,
            param: LSTMParam, seqIdx: Int, layerIdx: Int, dropout: Float = 0f): LSTMState = {
     val inDataa = {
-      if (dropout > 0f) Symbol.Dropout()()(Map("data" -> inData, "p" -> dropout))
+      if (dropout > 0f) Symbol.api.Dropout(data = Some(inData), p = Some(dropout))
       else inData
     }
-    val i2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_i2h")()(Map("data" -> inDataa,
-                                                       "weight" -> param.i2hWeight,
-                                                       "bias" -> param.i2hBias,
-                                                       "num_hidden" -> numHidden * 4))
-    val h2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_h2h")()(Map("data" -> prevState.h,
-                                                       "weight" -> param.h2hWeight,
-                                                       "bias" -> param.h2hBias,
-                                                       "num_hidden" -> numHidden * 4))
+    val i2h = Symbol.api.FullyConnected(data = Some(inDataa), weight = Some(param.i2hWeight),
+      bias = Some(param.i2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_i2h")
+    val h2h = Symbol.api.FullyConnected(data = Some(prevState.h), weight = Some(param.h2hWeight),
+      bias = Some(param.h2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_h2h")
     val gates = i2h + h2h
-    val sliceGates = Symbol.SliceChannel(s"t${seqIdx}_l${layerIdx}_slice")(
-      gates)(Map("num_outputs" -> 4))
-    val ingate = Symbol.Activation()()(Map("data" -> sliceGates.get(0), "act_type" -> "sigmoid"))
-    val inTransform = Symbol.Activation()()(Map("data" -> sliceGates.get(1), "act_type" -> "tanh"))
-    val forgetGate = Symbol.Activation()()(
-      Map("data" -> sliceGates.get(2), "act_type" -> "sigmoid"))
-    val outGate = Symbol.Activation()()(Map("data" -> sliceGates.get(3), "act_type" -> "sigmoid"))
+    val sliceGates = Symbol.api.SliceChannel(data = Some(gates), num_outputs = 4,
+      name = s"t${seqIdx}_l${layerIdx}_slice")
+    val ingate = Symbol.api.Activation(data = Some(sliceGates.get(0)), act_type = "sigmoid")
+    val inTransform = Symbol.api.Activation(data = Some(sliceGates.get(1)), act_type = "tanh")
+    val forgetGate = Symbol.api.Activation(data = Some(sliceGates.get(2)), act_type = "sigmoid")
+    val outGate = Symbol.api.Activation(data = Some(sliceGates.get(3)), act_type = "sigmoid")
     val nextC = (forgetGate * prevState.c) + (ingate * inTransform)
-    val nextH = outGate * Symbol.Activation()()(Map("data" -> nextC, "act_type" -> "tanh"))
+    val nextH = outGate * Symbol.api.Activation(data = Some(nextC), "tanh")
     LSTMState(c = nextC, h = nextH)
   }
 
@@ -74,11 +66,11 @@ object Lstm {
     val lastStatesBuf = ArrayBuffer[LSTMState]()
     for (i <- 0 until numLstmLayer) {
       paramCellsBuf.append(LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
-                                     i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
-                                     h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
-                                     h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
+        i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
+        h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
+        h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
       lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
-                                     h = Symbol.Variable(s"l${i}_init_h_beta")))
+        h = Symbol.Variable(s"l${i}_init_h_beta")))
     }
     val paramCells = paramCellsBuf.toArray
     val lastStates = lastStatesBuf.toArray
@@ -87,10 +79,10 @@ object Lstm {
     // embeding layer
     val data = Symbol.Variable("data")
     var label = Symbol.Variable("softmax_label")
-    val embed = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
-                                           "weight" -> embedWeight, "output_dim" -> numEmbed))
-    val wordvec = Symbol.SliceChannel()()(
-      Map("data" -> embed, "num_outputs" -> seqLen, "squeeze_axis" -> 1))
+    val embed = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
+      weight = Some(embedWeight), output_dim = numEmbed, name = "embed")
+    val wordvec = Symbol.api.SliceChannel(data = Some(embed),
+      num_outputs = seqLen, squeeze_axis = Some(true))
 
     val hiddenAll = ArrayBuffer[Symbol]()
     var dpRatio = 0f
@@ -101,22 +93,23 @@ object Lstm {
       for (i <- 0 until numLstmLayer) {
         if (i == 0) dpRatio = 0f else dpRatio = dropout
         val nextState = lstm(numHidden, inData = hidden,
-                             prevState = lastStates(i),
-                             param = paramCells(i),
-                             seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
+          prevState = lastStates(i),
+          param = paramCells(i),
+          seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
         hidden = nextState.h
         lastStates(i) = nextState
       }
       // decoder
-      if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
+      if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
       hiddenAll.append(hidden)
     }
-    val hiddenConcat = Symbol.Concat()(hiddenAll: _*)(Map("dim" -> 0))
-    val pred = Symbol.FullyConnected("pred")()(Map("data" -> hiddenConcat, "num_hidden" -> numLabel,
-                                                   "weight" -> clsWeight, "bias" -> clsBias))
-    label = Symbol.transpose()(label)()
-    label = Symbol.Reshape()()(Map("data" -> label, "target_shape" -> "(0,)"))
-    val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> pred, "label" -> label))
+    val hiddenConcat = Symbol.api.Concat(data = hiddenAll.toArray, num_args = hiddenAll.length,
+      dim = Some(0))
+    val pred = Symbol.api.FullyConnected(data = Some(hiddenConcat), num_hidden = numLabel,
+      weight = Some(clsWeight), bias = Some(clsBias))
+    label = Symbol.api.transpose(data = Some(label))
+    label = Symbol.api.Reshape(data = Some(label), target_shape = Some(Shape(0)))
+    val sm = Symbol.api.SoftmaxOutput(data = Some(pred), label = Some(label), name = "softmax")
     sm
   }
 
@@ -131,35 +124,35 @@ object Lstm {
     var lastStates = Array[LSTMState]()
     for (i <- 0 until numLstmLayer) {
       paramCells = paramCells :+ LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
-                                           i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
-                                           h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
-                                           h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
+        i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
+        h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
+        h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
       lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
-                                           h = Symbol.Variable(s"l${i}_init_h_beta"))
+        h = Symbol.Variable(s"l${i}_init_h_beta"))
     }
     assert(lastStates.length == numLstmLayer)
 
     val data = Symbol.Variable("data")
 
-    var hidden = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
-                                             "weight" -> embedWeight, "output_dim" -> numEmbed))
+    var hidden = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
+      weight = Some(embedWeight), output_dim = numEmbed, name = "embed")
 
     var dpRatio = 0f
     // stack LSTM
     for (i <- 0 until numLstmLayer) {
       if (i == 0) dpRatio = 0f else dpRatio = dropout
       val nextState = lstm(numHidden, inData = hidden,
-                           prevState = lastStates(i),
-                           param = paramCells(i),
-                           seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
+        prevState = lastStates(i),
+        param = paramCells(i),
+        seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
       hidden = nextState.h
       lastStates(i) = nextState
     }
     // decoder
-    if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
-    val fc = Symbol.FullyConnected("pred")()(Map("data" -> hidden, "num_hidden" -> numLabel,
-                                      "weight" -> clsWeight, "bias" -> clsBias))
-    val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> fc))
+    if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
+    val fc = Symbol.api.FullyConnected(data = Some(hidden),
+      num_hidden = numLabel, weight = Some(clsWeight), bias = Some(clsBias))
+    val sm = Symbol.api.SoftmaxOutput(data = Some(fc), name = "softmax")
     var output = Array(sm)
     for (state <- lastStates) {
       output = output :+ state.c
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
index 44ee6e7..f7a01ba 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
@@ -30,9 +30,8 @@ import org.apache.mxnet.module.BucketingModule
 import org.apache.mxnet.module.FitParams
 
 /**
- * Bucketing LSTM examples
- * @author Yizhi Liu
- */
+  * Bucketing LSTM examples
+  */
 class LstmBucketing {
   @Option(name = "--data-train", usage = "training set")
   private val dataTrain: String = "example/rnn/sherlockholmes.train.txt"
@@ -61,6 +60,60 @@ object LstmBucketing {
     Math.exp(loss / labelArr.length).toFloat
   }
 
+  def runTraining(trainData : String, validationData : String,
+                  ctx : Array[Context], numEpoch : Int): Unit = {
+    val batchSize = 32
+    val buckets = Array(10, 20, 30, 40, 50, 60)
+    val numHidden = 200
+    val numEmbed = 200
+    val numLstmLayer = 2
+
+    logger.info("Building vocab ...")
+    val vocab = BucketIo.defaultBuildVocab(trainData)
+
+    def BucketSymGen(key: AnyRef):
+    (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
+      val seqLen = key.asInstanceOf[Int]
+      val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
+        numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
+      (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
+    }
+
+    val initC = (0 until numLstmLayer).map(l =>
+      (s"l${l}_init_c_beta", (batchSize, numHidden))
+    )
+    val initH = (0 until numLstmLayer).map(l =>
+      (s"l${l}_init_h_beta", (batchSize, numHidden))
+    )
+    val initStates = initC ++ initH
+
+    val dataTrain = new BucketSentenceIter(trainData, vocab,
+      buckets, batchSize, initStates)
+    val dataVal = new BucketSentenceIter(validationData, vocab,
+      buckets, batchSize, initStates)
+
+    val model = new BucketingModule(
+      symGen = BucketSymGen,
+      defaultBucketKey = dataTrain.defaultBucketKey,
+      contexts = ctx)
+
+    val fitParams = new FitParams()
+    fitParams.setEvalMetric(
+      new CustomMetric(perplexity, name = "perplexity"))
+    fitParams.setKVStore("device")
+    fitParams.setOptimizer(
+      new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
+    fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
+    fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
+
+    logger.info("Start training ...")
+    model.fit(
+      trainData = dataTrain,
+      evalData = Some(dataVal),
+      numEpoch = numEpoch, fitParams)
+    logger.info("Finished training...")
+  }
+
   def main(args: Array[String]): Unit = {
     val inst = new LstmBucketing
     val parser: CmdLineParser = new CmdLineParser(inst)
@@ -71,56 +124,7 @@ object LstmBucketing {
         else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
         else Array(Context.cpu(0))
 
-      val batchSize = 32
-      val buckets = Array(10, 20, 30, 40, 50, 60)
-      val numHidden = 200
-      val numEmbed = 200
-      val numLstmLayer = 2
-
-      logger.info("Building vocab ...")
-      val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)
-
-      def BucketSymGen(key: AnyRef):
-        (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
-        val seqLen = key.asInstanceOf[Int]
-        val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
-          numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
-        (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
-      }
-
-      val initC = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_c_beta", (batchSize, numHidden))
-      )
-      val initH = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_h_beta", (batchSize, numHidden))
-      )
-      val initStates = initC ++ initH
-
-      val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab,
-        buckets, batchSize, initStates)
-      val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
-        buckets, batchSize, initStates)
-
-      val model = new BucketingModule(
-        symGen = BucketSymGen,
-        defaultBucketKey = dataTrain.defaultBucketKey,
-        contexts = contexts)
-
-      val fitParams = new FitParams()
-      fitParams.setEvalMetric(
-        new CustomMetric(perplexity, name = "perplexity"))
-      fitParams.setKVStore("device")
-      fitParams.setOptimizer(
-        new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
-      fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
-      fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
-
-      logger.info("Start training ...")
-      model.fit(
-        trainData = dataTrain,
-        evalData = Some(dataVal),
-        numEpoch = inst.numEpoch, fitParams)
-      logger.info("Finished training...")
+      runTraining(inst.dataTrain, inst.dataVal, contexts, 5)
     } catch {
       case ex: Exception =>
         logger.error(ex.getMessage, ex)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md
new file mode 100644
index 0000000..5289fc7
--- /dev/null
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md
@@ -0,0 +1,48 @@
+# RNN Example for MXNet Scala
+This folder contains the following examples written in the new Scala type-safe API:
+- [x] LSTM Bucketing
+- [x] CharRNN Inference: Generate similar text based on the model
+- [x] CharRNN Training: Train the language model using an RNN
+
+These examples are only for illustration and are not tuned to achieve the best accuracy.
+
+## Setup
+### Download the Network Definition, Weights and Training Data
+`obama.zip` contains the training inputs (Obama's speeches) for the CharRNN examples, and the `sherlockholmes.*.txt` files contain the data for LSTM Bucketing
+```bash
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/obama.zip
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.train.txt
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.valid.txt
+```
+### Unzip the file
+```bash
+unzip obama.zip
+```
+### Argument Configuration
+Then define the arguments you would like to pass to the model:
+
+#### LSTM Bucketing
+```bash
+--data-train
+<path>/sherlockholmes.train.txt
+--data-val
+<path>/sherlockholmes.valid.txt
+--cpus
+<comma-separated CPU device ids, e.g. 0>
+--gpus
+<comma-separated GPU device ids, e.g. 0,1>
+```
+#### TrainCharRnn
+```bash
+--data-path
+<path>/obama.txt
+--save-model-path
+<path>/
+```
+#### TestCharRnn
+```bash
+--data-path
+<path>/obama.txt
+--model-prefix
+<path>/obama
+```
\ No newline at end of file
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
index 243b70c..4786d5d 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -25,66 +25,68 @@ import scala.collection.JavaConverters._
 /**
  * Follows the demo, to test the char rnn:
  * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
  */
 object TestCharRnn {
 
   private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn])
 
-  def main(args: Array[String]): Unit = {
-    val stcr = new TestCharRnn
-    val parser: CmdLineParser = new CmdLineParser(stcr)
-    try {
-      parser.parseArgument(args.toList.asJava)
-      assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null)
+  def runTestCharRNN(dataPath: String, modelPrefix: String, starterSentence : String): Unit = {
+    // The batch size for training
+    val batchSize = 32
+    // We can support various length input
+    // For this problem, we cut each input sentence to length of 129
+    // So we only need fix length bucket
+    val buckets = List(129)
+    // hidden unit in LSTM cell
+    val numHidden = 512
+    // embedding dimension, which is, map a char to a 256 dim vector
+    val numEmbed = 256
+    // number of lstm layer
+    val numLstmLayer = 3
 
-      // The batch size for training
-      val batchSize = 32
-      // We can support various length input
-      // For this problem, we cut each input sentence to length of 129
-      // So we only need fix length bucket
-      val buckets = List(129)
-      // hidden unit in LSTM cell
-      val numHidden = 512
-      // embedding dimension, which is, map a char to a 256 dim vector
-      val numEmbed = 256
-      // number of lstm layer
-      val numLstmLayer = 3
+    // build char vocabluary from input
+    val vocab = Utils.buildVocab(dataPath)
 
-      // build char vocabluary from input
-      val vocab = Utils.buildVocab(stcr.dataPath)
+    // load from check-point
+    val (_, argParams, _) = Model.loadCheckpoint(modelPrefix, 75)
 
-      // load from check-point
-      val (_, argParams, _) = Model.loadCheckpoint(stcr.modelPrefix, 75)
+    // build an inference model
+    val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
+      numHidden = numHidden, numEmbed = numEmbed,
+      numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
 
-      // build an inference model
-      val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
-                           numHidden = numHidden, numEmbed = numEmbed,
-                           numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
+    // generate a sequence of 1200 chars
+    val seqLength = 1200
+    val inputNdarray = NDArray.zeros(1)
+    val revertVocab = Utils.makeRevertVocab(vocab)
 
-      // generate a sequence of 1200 chars
-      val seqLength = 1200
-      val inputNdarray = NDArray.zeros(1)
-      val revertVocab = Utils.makeRevertVocab(vocab)
+    // Feel free to change the starter sentence
+    var output = starterSentence
+    val randomSample = true
+    var newSentence = true
+    val ignoreLength = output.length()
 
-      // Feel free to change the starter sentence
-      var output = stcr.starterSentence
-      val randomSample = true
-      var newSentence = true
-      val ignoreLength = output.length()
+    for (i <- 0 until seqLength) {
+      if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
+      else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray)
+      val prob = model.forward(inputNdarray, newSentence)
+      newSentence = false
+      val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
+      if (nextChar == "") newSentence = true
+      if (i >= ignoreLength) output = output ++ nextChar
+    }
 
-      for (i <- 0 until seqLength) {
-        if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
-        else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray)
-        val prob = model.forward(inputNdarray, newSentence)
-        newSentence = false
-        val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
-        if (nextChar == "") newSentence = true
-        if (i >= ignoreLength) output = output ++ nextChar
-      }
+    // Let's see what we can learned from char in Obama's speech.
+    logger.info(output)
+  }
 
-      // Let's see what we can learned from char in Obama's speech.
-      logger.info(output)
+  def main(args: Array[String]): Unit = {
+    val stcr = new TestCharRnn
+    val parser: CmdLineParser = new CmdLineParser(stcr)
+    try {
+      parser.parseArgument(args.toList.asJava)
+      assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null)
+      runTestCharRNN(stcr.dataPath, stcr.modelPrefix, stcr.starterSentence)
     } catch {
       case ex: Exception => {
         logger.error(ex.getMessage, ex)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
index 3afb936..fb59705 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
@@ -24,143 +24,144 @@ import scala.collection.JavaConverters._
 import org.apache.mxnet.optimizer.Adam
 
 /**
- * Follows the demo, to train the char rnn:
- * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
- */
+  * Follows the demo, to train the char rnn:
+  * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
+  */
 object TrainCharRnn {
 
   private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn])
 
-  def main(args: Array[String]): Unit = {
-    val incr = new TrainCharRnn
-    val parser: CmdLineParser = new CmdLineParser(incr)
-    try {
-      parser.parseArgument(args.toList.asJava)
-      assert(incr.dataPath != null && incr.saveModelPath != null)
-
-      // The batch size for training
-      val batchSize = 32
-      // We can support various length input
-      // For this problem, we cut each input sentence to length of 129
-      // So we only need fix length bucket
-      val buckets = Array(129)
-      // hidden unit in LSTM cell
-      val numHidden = 512
-      // embedding dimension, which is, map a char to a 256 dim vector
-      val numEmbed = 256
-      // number of lstm layer
-      val numLstmLayer = 3
-      // we will show a quick demo in 2 epoch
-      // and we will see result by training 75 epoch
-      val numEpoch = 75
-      // learning rate
-      val learningRate = 0.001f
-      // we will use pure sgd without momentum
-      val momentum = 0.0f
-
-      val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu)
-      val vocab = Utils.buildVocab(incr.dataPath)
-
-      // generate symbol for a length
-      def symGen(seqLen: Int): Symbol = {
-        Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
-                    numHidden = numHidden, numEmbed = numEmbed,
-                    numLabel = vocab.size + 1, dropout = 0.2f)
-      }
+  def runTrainCharRnn(dataPath: String, saveModelPath: String,
+                      ctx : Context, numEpoch : Int): Unit = {
+    // The batch size for training
+    val batchSize = 32
+    // We can support various length input
+    // For this problem, we cut each input sentence to length of 129
+    // So we only need fix length bucket
+    val buckets = Array(129)
+    // hidden unit in LSTM cell
+    val numHidden = 512
+    // embedding dimension, which is, map a char to a 256 dim vector
+    val numEmbed = 256
+    // number of lstm layer
+    val numLstmLayer = 3
+    // we will show a quick demo in 2 epoch
+    // learning rate
+    val learningRate = 0.001f
+    // we will use pure sgd without momentum
+    val momentum = 0.0f
+
+    val vocab = Utils.buildVocab(dataPath)
+
+    // generate symbol for a length
+    def symGen(seqLen: Int): Symbol = {
+      Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
+        numHidden = numHidden, numEmbed = numEmbed,
+        numLabel = vocab.size + 1, dropout = 0.2f)
+    }
 
-      // initalize states for LSTM
-      val initC = for (l <- 0 until numLstmLayer)
-        yield (s"l${l}_init_c_beta", (batchSize, numHidden))
-      val initH = for (l <- 0 until numLstmLayer)
-        yield (s"l${l}_init_h_beta", (batchSize, numHidden))
-      val initStates = initC ++ initH
+    // initalize states for LSTM
+    val initC = for (l <- 0 until numLstmLayer)
+      yield (s"l${l}_init_c_beta", (batchSize, numHidden))
+    val initH = for (l <- 0 until numLstmLayer)
+      yield (s"l${l}_init_h_beta", (batchSize, numHidden))
+    val initStates = initC ++ initH
 
-      val dataTrain = new BucketIo.BucketSentenceIter(incr.dataPath, vocab, buckets,
-                                          batchSize, initStates, seperateChar = "\n",
-                                          text2Id = Utils.text2Id, readContent = Utils.readContent)
+    val dataTrain = new BucketIo.BucketSentenceIter(dataPath, vocab, buckets,
+      batchSize, initStates, seperateChar = "\n",
+      text2Id = Utils.text2Id, readContent = Utils.readContent)
 
-      // the network symbol
-      val symbol = symGen(buckets(0))
+    // the network symbol
+    val symbol = symGen(buckets(0))
 
-      val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
-      val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
+    val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
+    val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
 
-      val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
+    val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
 
-      val argNames = symbol.listArguments()
-      val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
-      val auxNames = symbol.listAuxiliaryStates()
-      val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
+    val argNames = symbol.listArguments()
+    val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
+    val auxNames = symbol.listAuxiliaryStates()
+    val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
 
-      val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
-        !datasAndLabels.contains(name)
-      }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
+    val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
+      !datasAndLabels.contains(name)
+    }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
 
-      argDict.foreach { case (name, ndArray) =>
-        if (!datasAndLabels.contains(name)) {
-          initializer.initWeight(name, ndArray)
-        }
+    argDict.foreach { case (name, ndArray) =>
+      if (!datasAndLabels.contains(name)) {
+        initializer.initWeight(name, ndArray)
       }
+    }
 
-      val data = argDict("data")
-      val label = argDict("softmax_label")
+    val data = argDict("data")
+    val label = argDict("softmax_label")
 
-      val executor = symbol.bind(ctx, argDict, gradDict)
+    val executor = symbol.bind(ctx, argDict, gradDict)
 
-      val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
+    val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
 
-      val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
-        (idx, name, grad, opt.createState(idx, argDict(name)))
-      }
+    val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
+      (idx, name, grad, opt.createState(idx, argDict(name)))
+    }
 
-      val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
-      val batchEndCallback = new Callback.Speedometer(batchSize, 50)
-      val epochEndCallback = Utils.doCheckpoint(s"${incr.saveModelPath}/obama")
-
-      for (epoch <- 0 until numEpoch) {
-        // Training phase
-        val tic = System.currentTimeMillis
-        evalMetric.reset()
-        var nBatch = 0
-        var epochDone = false
-        // Iterate over training data.
-        dataTrain.reset()
-        while (!epochDone) {
-          var doReset = true
-          while (doReset && dataTrain.hasNext) {
-            val dataBatch = dataTrain.next()
-
-            data.set(dataBatch.data(0))
-            label.set(dataBatch.label(0))
-            executor.forward(isTrain = true)
-            executor.backward()
-            paramsGrads.foreach { case (idx, name, grad, optimState) =>
-              opt.update(idx, argDict(name), grad, optimState)
-            }
-
-            // evaluate at end, so out_cpu_array can lazy copy
-            evalMetric.update(dataBatch.label, executor.outputs)
-
-            nBatch += 1
-            batchEndCallback.invoke(epoch, nBatch, evalMetric)
+    val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
+    val batchEndCallback = new Callback.Speedometer(batchSize, 50)
+    val epochEndCallback = Utils.doCheckpoint(s"${saveModelPath}/obama")
+
+    for (epoch <- 0 until numEpoch) {
+      // Training phase
+      val tic = System.currentTimeMillis
+      evalMetric.reset()
+      var nBatch = 0
+      var epochDone = false
+      // Iterate over training data.
+      dataTrain.reset()
+      while (!epochDone) {
+        var doReset = true
+        while (doReset && dataTrain.hasNext) {
+          val dataBatch = dataTrain.next()
+
+          data.set(dataBatch.data(0))
+          label.set(dataBatch.label(0))
+          executor.forward(isTrain = true)
+          executor.backward()
+          paramsGrads.foreach { case (idx, name, grad, optimState) =>
+            opt.update(idx, argDict(name), grad, optimState)
           }
-          if (doReset) {
-            dataTrain.reset()
-          }
-          // this epoch is done
-          epochDone = true
+
+          // evaluate at end, so out_cpu_array can lazy copy
+          evalMetric.update(dataBatch.label, executor.outputs)
+
+          nBatch += 1
+          batchEndCallback.invoke(epoch, nBatch, evalMetric)
         }
-        val (name, value) = evalMetric.get
-        name.zip(value).foreach { case (n, v) =>
-          logger.info(s"Epoch[$epoch] Train-$n=$v")
+        if (doReset) {
+          dataTrain.reset()
         }
-        val toc = System.currentTimeMillis
-        logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
-
-        epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+        // this epoch is done
+        epochDone = true
       }
-      executor.dispose()
+      val (name, value) = evalMetric.get
+      name.zip(value).foreach { case (n, v) =>
+        logger.info(s"Epoch[$epoch] Train-$n=$v")
+      }
+      val toc = System.currentTimeMillis
+      logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
+
+      epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+    }
+    executor.dispose()
+  }
+
+  def main(args: Array[String]): Unit = {
+    val incr = new TrainCharRnn
+    val parser: CmdLineParser = new CmdLineParser(incr)
+    try {
+      parser.parseArgument(args.toList.asJava)
+      val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu)
+      assert(incr.dataPath != null && incr.saveModelPath != null)
+      runTrainCharRnn(incr.dataPath, incr.saveModelPath, ctx, 75)
     } catch {
       case ex: Exception => {
         logger.error(ex.getMessage, ex)
@@ -172,12 +173,6 @@ object TrainCharRnn {
 }
 
 class TrainCharRnn {
-  /*
-   * Get Training Data:  E.g.
-   * mkdir data; cd data
-   * wget "http://data.mxnet.io/mxnet/data/char_lstm.zip"
-   * unzip -o char_lstm.zip
-   */
   @Option(name = "--data-path", usage = "the input train data file")
   private val dataPath: String = "./data/obama.txt"
   @Option(name = "--save-model-path", usage = "the model saving path")
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
index c290230..3f9a984 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
@@ -25,9 +25,6 @@ import org.apache.mxnet.Model
 import org.apache.mxnet.Symbol
 import scala.util.Random
 
-/**
- * @author Depeng Liang
- */
 object Utils {
 
   def readContent(path: String): String = Source.fromFile(path).mkString
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
new file mode 100644
index 0000000..b393a43
--- /dev/null
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.rnn
+
+
+import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnetexamples.Util
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
+import org.slf4j.LoggerFactory
+
+import scala.sys.process.Process
+
+/** CI smoke tests for the RNN examples; ignored for now because of memory leaks. */
+@Ignore
+class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll {
+  private val logger = LoggerFactory.getLogger(classOf[ExampleRNNSuite])
+  // Every fixture is downloaded to, and unpacked in, <java.io.tmpdir>/RNN
+  private val rnnDir = System.getProperty("java.io.tmpdir") + "/RNN"
+
+  /** True only when SCALA_TEST_ON_GPU parses as 1; unset or malformed values mean CPU. */
+  private def testOnGpu: Boolean = Option(System.getenv("SCALA_TEST_ON_GPU"))
+    .exists(v => scala.util.Try(v.trim.toInt).toOption.contains(1))
+
+  /** Downloads the RNN fixtures and unpacks obama.zip; needs `unzip` on the PATH. */
+  override def beforeAll(): Unit = {
+    logger.info("Downloading LSTM model")
+    logger.info("tempDirPath: %s".format(System.getProperty("java.io.tmpdir")))
+    val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/"
+    Seq("obama.zip", "sherlockholmes.train.txt", "sherlockholmes.valid.txt").foreach { f =>
+      Util.downloadUrl(baseUrl + f, s"$rnnDir/$f")
+    }
+    // Pass an argument list (not one string) so temp paths with spaces survive, and
+    // -o so a rerun does not stall on unzip's interactive overwrite prompt.
+    // TODO: Need to confirm with Windows
+    val exitCode = Process(Seq("unzip", "-o", s"$rnnDir/obama.zip", "-d", s"$rnnDir/")).!
+    require(exitCode == 0, s"unzip of $rnnDir/obama.zip failed with exit code $exitCode")
+  }
+
+  test("Example CI: Test LSTM Bucketing") {
+    val ctx = if (testOnGpu) Context.gpu() else Context.cpu()
+    LstmBucketing.runTraining(s"$rnnDir/sherlockholmes.train.txt",
+      s"$rnnDir/sherlockholmes.valid.txt", Array(ctx), 1)
+  }
+
+  test("Example CI: Test TrainCharRNN") {
+    if (testOnGpu) {
+      // saveModelPath is java.io.tmpdir itself (checkpoint prefix <tmpdir>/obama), not RNN/
+      TrainCharRnn.runTrainCharRnn(s"$rnnDir/obama.txt",
+        System.getProperty("java.io.tmpdir"), Context.gpu(), 1)
+    } else {
+      logger.info("CPU not supported for this test, skipped...")
+    }
+  }
+
+  test("Example CI: Test TestCharRNN") {
+    // runTestCharRNN is given no Context, so this test runs whatever SCALA_TEST_ON_GPU says
+    TestCharRnn.runTestCharRNN(s"$rnnDir/obama.txt",
+      s"$rnnDir/obama", "The joke")
+  }
+}