Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/21 18:09:25 UTC

[GitHub] nswamy closed pull request #11753: [MXNET-836] RNN Example for Scala

nswamy closed pull request #11753: [MXNET-836] RNN Example for Scala
URL: https://github.com/apache/incubator-mxnet/pull/11753

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index d4b17074d48..6d414bb0328 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -34,7 +34,7 @@ object BucketIo {
   type ReadContent = String => String
 
   def defaultReadContent(path: String): String = {
-    Source.fromFile(path).mkString.replaceAll("\\. |\n", " <eos> ")
+    Source.fromFile(path, "UTF-8").mkString.replaceAll("\\. |\n", " <eos> ")
   }
 
   def defaultBuildVocab(path: String): Map[String, Int] = {
@@ -56,7 +56,7 @@ object BucketIo {
       val tmp = sentence.split(" ").filter(_.length() > 0)
       for (w <- tmp) yield theVocab(w)
     }
-    words.toArray
+    words
   }
 
   def defaultGenBuckets(sentences: Array[String], batchSize: Int,
@@ -162,8 +162,6 @@ object BucketIo {
       labelBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket)))
     }
 
-    private val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
-
     private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
       tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
     }
@@ -208,12 +206,13 @@ object BucketIo {
         tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
       }
       val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape)
-      new DataBatch(IndexedSeq(dataBuf) ++ initStateArrays,
-                    IndexedSeq(labelBuf),
-                    getIndex(),
-                    getPad(),
-                    this.buckets(bucketIdx).asInstanceOf[AnyRef],
-                    batchProvideData, batchProvideLabel)
+      val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
+      new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
+        IndexedSeq(labelBuf.copy()),
+        getIndex(),
+        getPad(),
+        this.buckets(bucketIdx).asInstanceOf[AnyRef],
+        batchProvideData, batchProvideLabel)
     }
 
     /**
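
Two things change in BucketIo.scala: file reading now forces UTF-8, and produceDataBatch allocates fresh init-state arrays and copies the data/label buffers for every batch instead of handing out the shared, reused buffers. The copy matters because the iterator refills the same NDArrays on its next call, which would silently mutate a batch a consumer still holds. A minimal sketch of the aliasing hazard, using only NDArray calls that appear in this diff (the object name is illustrative):

```scala
import org.apache.mxnet.NDArray

object BufferAliasingSketch {
  def main(args: Array[String]): Unit = {
    // Shared-buffer style: every "batch" is the same NDArray instance.
    val shared = NDArray.zeros(2, 3)
    val batchA = shared                       // alias, not a copy
    shared.set(1f)                            // refilling for the next batch...
    println(batchA.toArray.mkString(","))     // ...also mutates batchA: all 1.0

    // Copy-per-batch style, as the revised produceDataBatch does.
    val buf = NDArray.zeros(2, 3)
    val batchB = buf.copy()                   // independent storage
    buf.set(1f)
    println(batchB.toArray.mkString(","))     // batchB is still all 0.0
  }
}
```
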
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
index bf29a47fcf8..872ef7871fb 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
@@ -18,13 +18,10 @@
 
 package org.apache.mxnetexamples.rnn
 
-import org.apache.mxnet.Symbol
+import org.apache.mxnet.{Shape, Symbol}
 
 import scala.collection.mutable.ArrayBuffer
 
-/**
- * @author Depeng Liang
- */
 object Lstm {
 
   final case class LSTMState(c: Symbol, h: Symbol)
@@ -35,27 +32,22 @@ object Lstm {
   def lstm(numHidden: Int, inData: Symbol, prevState: LSTMState,
            param: LSTMParam, seqIdx: Int, layerIdx: Int, dropout: Float = 0f): LSTMState = {
     val inDataa = {
-      if (dropout > 0f) Symbol.Dropout()()(Map("data" -> inData, "p" -> dropout))
+      if (dropout > 0f) Symbol.api.Dropout(data = Some(inData), p = Some(dropout))
       else inData
     }
-    val i2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_i2h")()(Map("data" -> inDataa,
-                                                       "weight" -> param.i2hWeight,
-                                                       "bias" -> param.i2hBias,
-                                                       "num_hidden" -> numHidden * 4))
-    val h2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_h2h")()(Map("data" -> prevState.h,
-                                                       "weight" -> param.h2hWeight,
-                                                       "bias" -> param.h2hBias,
-                                                       "num_hidden" -> numHidden * 4))
+    val i2h = Symbol.api.FullyConnected(data = Some(inDataa), weight = Some(param.i2hWeight),
+      bias = Some(param.i2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_i2h")
+    val h2h = Symbol.api.FullyConnected(data = Some(prevState.h), weight = Some(param.h2hWeight),
+      bias = Some(param.h2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_h2h")
     val gates = i2h + h2h
-    val sliceGates = Symbol.SliceChannel(s"t${seqIdx}_l${layerIdx}_slice")(
-      gates)(Map("num_outputs" -> 4))
-    val ingate = Symbol.Activation()()(Map("data" -> sliceGates.get(0), "act_type" -> "sigmoid"))
-    val inTransform = Symbol.Activation()()(Map("data" -> sliceGates.get(1), "act_type" -> "tanh"))
-    val forgetGate = Symbol.Activation()()(
-      Map("data" -> sliceGates.get(2), "act_type" -> "sigmoid"))
-    val outGate = Symbol.Activation()()(Map("data" -> sliceGates.get(3), "act_type" -> "sigmoid"))
+    val sliceGates = Symbol.api.SliceChannel(data = Some(gates), num_outputs = 4,
+      name = s"t${seqIdx}_l${layerIdx}_slice")
+    val ingate = Symbol.api.Activation(data = Some(sliceGates.get(0)), act_type = "sigmoid")
+    val inTransform = Symbol.api.Activation(data = Some(sliceGates.get(1)), act_type = "tanh")
+    val forgetGate = Symbol.api.Activation(data = Some(sliceGates.get(2)), act_type = "sigmoid")
+    val outGate = Symbol.api.Activation(data = Some(sliceGates.get(3)), act_type = "sigmoid")
     val nextC = (forgetGate * prevState.c) + (ingate * inTransform)
-    val nextH = outGate * Symbol.Activation()()(Map("data" -> nextC, "act_type" -> "tanh"))
+    val nextH = outGate * Symbol.api.Activation(data = Some(nextC), act_type = "tanh")
     LSTMState(c = nextC, h = nextH)
   }
 
@@ -74,11 +66,11 @@ object Lstm {
     val lastStatesBuf = ArrayBuffer[LSTMState]()
     for (i <- 0 until numLstmLayer) {
       paramCellsBuf.append(LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
-                                     i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
-                                     h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
-                                     h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
+        i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
+        h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
+        h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
       lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
-                                     h = Symbol.Variable(s"l${i}_init_h_beta")))
+        h = Symbol.Variable(s"l${i}_init_h_beta")))
     }
     val paramCells = paramCellsBuf.toArray
     val lastStates = lastStatesBuf.toArray
@@ -87,10 +79,10 @@ object Lstm {
     // embedding layer
     val data = Symbol.Variable("data")
     var label = Symbol.Variable("softmax_label")
-    val embed = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
-                                           "weight" -> embedWeight, "output_dim" -> numEmbed))
-    val wordvec = Symbol.SliceChannel()()(
-      Map("data" -> embed, "num_outputs" -> seqLen, "squeeze_axis" -> 1))
+    val embed = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
+      weight = Some(embedWeight), output_dim = numEmbed, name = "embed")
+    val wordvec = Symbol.api.SliceChannel(data = Some(embed),
+      num_outputs = seqLen, squeeze_axis = Some(true))
 
     val hiddenAll = ArrayBuffer[Symbol]()
     var dpRatio = 0f
@@ -101,22 +93,23 @@ object Lstm {
       for (i <- 0 until numLstmLayer) {
         if (i == 0) dpRatio = 0f else dpRatio = dropout
         val nextState = lstm(numHidden, inData = hidden,
-                             prevState = lastStates(i),
-                             param = paramCells(i),
-                             seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
+          prevState = lastStates(i),
+          param = paramCells(i),
+          seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
         hidden = nextState.h
         lastStates(i) = nextState
       }
       // decoder
-      if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
+      if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
       hiddenAll.append(hidden)
     }
-    val hiddenConcat = Symbol.Concat()(hiddenAll: _*)(Map("dim" -> 0))
-    val pred = Symbol.FullyConnected("pred")()(Map("data" -> hiddenConcat, "num_hidden" -> numLabel,
-                                                   "weight" -> clsWeight, "bias" -> clsBias))
-    label = Symbol.transpose()(label)()
-    label = Symbol.Reshape()()(Map("data" -> label, "target_shape" -> "(0,)"))
-    val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> pred, "label" -> label))
+    val hiddenConcat = Symbol.api.Concat(data = hiddenAll.toArray, num_args = hiddenAll.length,
+      dim = Some(0))
+    val pred = Symbol.api.FullyConnected(data = Some(hiddenConcat), num_hidden = numLabel,
+      weight = Some(clsWeight), bias = Some(clsBias))
+    label = Symbol.api.transpose(data = Some(label))
+    label = Symbol.api.Reshape(data = Some(label), target_shape = Some(Shape(0)))
+    val sm = Symbol.api.SoftmaxOutput(data = Some(pred), label = Some(label), name = "softmax")
     sm
   }
 
@@ -131,35 +124,35 @@ object Lstm {
     var lastStates = Array[LSTMState]()
     for (i <- 0 until numLstmLayer) {
       paramCells = paramCells :+ LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
-                                           i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
-                                           h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
-                                           h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
+        i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
+        h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
+        h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
       lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
-                                           h = Symbol.Variable(s"l${i}_init_h_beta"))
+        h = Symbol.Variable(s"l${i}_init_h_beta"))
     }
     assert(lastStates.length == numLstmLayer)
 
     val data = Symbol.Variable("data")
 
-    var hidden = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
-                                             "weight" -> embedWeight, "output_dim" -> numEmbed))
+    var hidden = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
+      weight = Some(embedWeight), output_dim = numEmbed, name = "embed")
 
     var dpRatio = 0f
     // stack LSTM
     for (i <- 0 until numLstmLayer) {
       if (i == 0) dpRatio = 0f else dpRatio = dropout
       val nextState = lstm(numHidden, inData = hidden,
-                           prevState = lastStates(i),
-                           param = paramCells(i),
-                           seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
+        prevState = lastStates(i),
+        param = paramCells(i),
+        seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
       hidden = nextState.h
       lastStates(i) = nextState
     }
     // decoder
-    if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
-    val fc = Symbol.FullyConnected("pred")()(Map("data" -> hidden, "num_hidden" -> numLabel,
-                                      "weight" -> clsWeight, "bias" -> clsBias))
-    val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> fc))
+    if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
+    val fc = Symbol.api.FullyConnected(data = Some(hidden),
+      num_hidden = numLabel, weight = Some(clsWeight), bias = Some(clsBias))
+    val sm = Symbol.api.SoftmaxOutput(data = Some(fc), name = "softmax")
     var output = Array(sm)
     for (state <- lastStates) {
       output = output :+ state.c
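
The Lstm.scala change is mostly a mechanical migration from the stringly-typed `Symbol.Op(name)()(Map(...))` calls to the type-safe `Symbol.api.Op(...)` generated API, where optional symbol inputs are wrapped in `Some(...)` and a misspelled parameter name fails at compile time rather than at runtime. A sketch of one FullyConnected call in both styles, mirroring the removed and added lines above (the variables are placeholders):

```scala
import org.apache.mxnet.Symbol

object ApiStyles {
  val data   = Symbol.Variable("data")
  val weight = Symbol.Variable("w")
  val bias   = Symbol.Variable("b")

  // Old untyped style: operator keys and values are checked only at runtime.
  val fcOld = Symbol.FullyConnected("fc")()(Map(
    "data" -> data, "weight" -> weight, "bias" -> bias, "num_hidden" -> 512))

  // New type-safe style: named parameters, optional inputs wrapped in Some.
  val fcNew = Symbol.api.FullyConnected(data = Some(data), weight = Some(weight),
    bias = Some(bias), num_hidden = 512, name = "fc")
}
```
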
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
index 44ee6e778d2..f7a01bad133 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
@@ -30,9 +30,8 @@ import org.apache.mxnet.module.BucketingModule
 import org.apache.mxnet.module.FitParams
 
 /**
- * Bucketing LSTM examples
- * @author Yizhi Liu
- */
+  * Bucketing LSTM examples
+  */
 class LstmBucketing {
   @Option(name = "--data-train", usage = "training set")
   private val dataTrain: String = "example/rnn/sherlockholmes.train.txt"
@@ -61,6 +60,60 @@ object LstmBucketing {
     Math.exp(loss / labelArr.length).toFloat
   }
 
+  def runTraining(trainData: String, validationData: String,
+                  ctx: Array[Context], numEpoch: Int): Unit = {
+    val batchSize = 32
+    val buckets = Array(10, 20, 30, 40, 50, 60)
+    val numHidden = 200
+    val numEmbed = 200
+    val numLstmLayer = 2
+
+    logger.info("Building vocab ...")
+    val vocab = BucketIo.defaultBuildVocab(trainData)
+
+    def BucketSymGen(key: AnyRef):
+    (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
+      val seqLen = key.asInstanceOf[Int]
+      val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
+        numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
+      (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
+    }
+
+    val initC = (0 until numLstmLayer).map(l =>
+      (s"l${l}_init_c_beta", (batchSize, numHidden))
+    )
+    val initH = (0 until numLstmLayer).map(l =>
+      (s"l${l}_init_h_beta", (batchSize, numHidden))
+    )
+    val initStates = initC ++ initH
+
+    val dataTrain = new BucketSentenceIter(trainData, vocab,
+      buckets, batchSize, initStates)
+    val dataVal = new BucketSentenceIter(validationData, vocab,
+      buckets, batchSize, initStates)
+
+    val model = new BucketingModule(
+      symGen = BucketSymGen,
+      defaultBucketKey = dataTrain.defaultBucketKey,
+      contexts = ctx)
+
+    val fitParams = new FitParams()
+    fitParams.setEvalMetric(
+      new CustomMetric(perplexity, name = "perplexity"))
+    fitParams.setKVStore("device")
+    fitParams.setOptimizer(
+      new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
+    fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
+    fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
+
+    logger.info("Start training ...")
+    model.fit(
+      trainData = dataTrain,
+      evalData = Some(dataVal),
+      numEpoch = numEpoch, fitParams)
+    logger.info("Finished training...")
+  }
+
   def main(args: Array[String]): Unit = {
     val inst = new LstmBucketing
     val parser: CmdLineParser = new CmdLineParser(inst)
@@ -71,56 +124,7 @@ object LstmBucketing {
         else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
         else Array(Context.cpu(0))
 
-      val batchSize = 32
-      val buckets = Array(10, 20, 30, 40, 50, 60)
-      val numHidden = 200
-      val numEmbed = 200
-      val numLstmLayer = 2
-
-      logger.info("Building vocab ...")
-      val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)
-
-      def BucketSymGen(key: AnyRef):
-        (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
-        val seqLen = key.asInstanceOf[Int]
-        val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
-          numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
-        (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
-      }
-
-      val initC = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_c_beta", (batchSize, numHidden))
-      )
-      val initH = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_h_beta", (batchSize, numHidden))
-      )
-      val initStates = initC ++ initH
-
-      val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab,
-        buckets, batchSize, initStates)
-      val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
-        buckets, batchSize, initStates)
-
-      val model = new BucketingModule(
-        symGen = BucketSymGen,
-        defaultBucketKey = dataTrain.defaultBucketKey,
-        contexts = contexts)
-
-      val fitParams = new FitParams()
-      fitParams.setEvalMetric(
-        new CustomMetric(perplexity, name = "perplexity"))
-      fitParams.setKVStore("device")
-      fitParams.setOptimizer(
-        new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
-      fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
-      fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
-
-      logger.info("Start training ...")
-      model.fit(
-        trainData = dataTrain,
-        evalData = Some(dataVal),
-        numEpoch = inst.numEpoch, fitParams)
-      logger.info("Finished training...")
+      runTraining(inst.dataTrain, inst.dataVal, contexts, inst.numEpoch)
     } catch {
       case ex: Exception =>
         logger.error(ex.getMessage, ex)
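
The refactor moves the training body into runTraining so the new CI suite can call it directly, keeping the same CustomMetric: perplexity, the exponential of the mean per-token cross-entropy, matching the `Math.exp(loss / labelArr.length)` tail visible in the helper above (where loss accumulates negative log-probabilities, the standard definition). A self-contained sketch of that computation on plain arrays (values made up):

```scala
object PerplexitySketch {
  // Perplexity = exp(mean negative log-likelihood of the correct tokens);
  // probs(i) is the model's probability for the true token at position i.
  def perplexity(probs: Array[Double]): Double = {
    val nll = probs.map(p => -math.log(p)).sum
    math.exp(nll / probs.length)
  }

  def main(args: Array[String]): Unit = {
    // A model that gives every correct token probability 0.5 is exactly
    // as uncertain as a fair coin: perplexity 2.
    println(perplexity(Array(0.5, 0.5, 0.5, 0.5)))  // 2.0
  }
}
```
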
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md
new file mode 100644
index 00000000000..5289fc7b1b4
--- /dev/null
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md
@@ -0,0 +1,48 @@
+# RNN Example for MXNet Scala
+This folder contains the following examples, written in the new Scala type-safe API:
+- [x] LSTM Bucketing
+- [x] CharRNN Inference: generate similar text based on a trained model
+- [x] CharRNN Training: train the character language model using an RNN
+
+These examples are only for illustration and are not tuned for the best accuracy.
+
+## Setup
+### Download the Network Definition, Weights and Training Data
+`obama.zip` contains the training inputs (Obama's speeches) for the CharRNN examples, and the `sherlockholmes` files contain the data for LSTM Bucketing
+```bash
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/obama.zip
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.train.txt
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.valid.txt
+```
+### Unzip the file
+```bash
+unzip obama.zip
+```
+### Argument Configuration
+Then you need to define the arguments that you would like to pass to the model:
+
+#### LSTM Bucketing
+```bash
+--data-train
+<path>/sherlockholmes.train.txt
+--data-val
+<path>/sherlockholmes.valid.txt
+--cpus
+<comma-separated CPU ids, e.g. 0,1>
+--gpus
+<comma-separated GPU ids, e.g. 0>
+```
+#### TrainCharRnn
+```bash
+--data-path
+<path>/obama.txt
+--save-model-path
+<path>/
+```
+#### TestCharRnn
+```bash
+--data-path
+<path>/obama.txt
+--model-prefix
+<path>/obama
+```
\ No newline at end of file
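
The README covers the CLI flags; since the PR also exposes runTraining, runTrainCharRnn, and runTestCharRNN for the CI suite, the same examples can be driven programmatically. A sketch, assuming the data was downloaded to a local directory and that checkpoints follow MXNet's usual prefix-NNNN.params naming:

```scala
import org.apache.mxnet.Context
import org.apache.mxnetexamples.rnn.{LstmBucketing, TestCharRnn, TrainCharRnn}

object RunExamples {
  def main(args: Array[String]): Unit = {
    val base = "/path/to/data"  // assumed location of the downloaded files

    // LSTM bucketing on the Sherlock Holmes corpus, one epoch on CPU.
    LstmBucketing.runTraining(s"$base/sherlockholmes.train.txt",
      s"$base/sherlockholmes.valid.txt", Array(Context.cpu()), 1)

    // Char-RNN training on the Obama corpus (a GPU is strongly advised).
    TrainCharRnn.runTrainCharRnn(s"$base/obama.txt", base, Context.gpu(0), 1)

    // Char-RNN inference; the checkpoint epoch (75) is hard-coded in
    // runTestCharRNN, so a checkpoint like <base>/obama-0075.params must exist.
    TestCharRnn.runTestCharRNN(s"$base/obama.txt", s"$base/obama", "The joke")
  }
}
```
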
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
index 243b70c0670..4786d5d5953 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -25,66 +25,68 @@ import scala.collection.JavaConverters._
 /**
  * Follows the demo, to test the char rnn:
  * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
  */
 object TestCharRnn {
 
   private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn])
 
-  def main(args: Array[String]): Unit = {
-    val stcr = new TestCharRnn
-    val parser: CmdLineParser = new CmdLineParser(stcr)
-    try {
-      parser.parseArgument(args.toList.asJava)
-      assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null)
+  def runTestCharRNN(dataPath: String, modelPrefix: String, starterSentence: String): Unit = {
+    // The batch size for training
+    val batchSize = 32
+    // We can support variable-length input.
+    // For this problem we cut each input sentence to a length of 129,
+    // so we only need a single fixed-length bucket.
+    val buckets = List(129)
+    // hidden units per LSTM cell
+    val numHidden = 512
+    // embedding dimension: map each char to a 256-dim vector
+    val numEmbed = 256
+    // number of LSTM layers
+    val numLstmLayer = 3
 
-      // The batch size for training
-      val batchSize = 32
-      // We can support various length input
-      // For this problem, we cut each input sentence to length of 129
-      // So we only need fix length bucket
-      val buckets = List(129)
-      // hidden unit in LSTM cell
-      val numHidden = 512
-      // embedding dimension, which is, map a char to a 256 dim vector
-      val numEmbed = 256
-      // number of lstm layer
-      val numLstmLayer = 3
+    // build the char vocabulary from the input
+    val vocab = Utils.buildVocab(dataPath)
 
-      // build char vocabluary from input
-      val vocab = Utils.buildVocab(stcr.dataPath)
+    // load from check-point
+    val (_, argParams, _) = Model.loadCheckpoint(modelPrefix, 75)
 
-      // load from check-point
-      val (_, argParams, _) = Model.loadCheckpoint(stcr.modelPrefix, 75)
+    // build an inference model
+    val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
+      numHidden = numHidden, numEmbed = numEmbed,
+      numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
 
-      // build an inference model
-      val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
-                           numHidden = numHidden, numEmbed = numEmbed,
-                           numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
+    // generate a sequence of 1200 chars
+    val seqLength = 1200
+    val inputNdarray = NDArray.zeros(1)
+    val revertVocab = Utils.makeRevertVocab(vocab)
 
-      // generate a sequence of 1200 chars
-      val seqLength = 1200
-      val inputNdarray = NDArray.zeros(1)
-      val revertVocab = Utils.makeRevertVocab(vocab)
+    // Feel free to change the starter sentence
+    var output = starterSentence
+    val randomSample = true
+    var newSentence = true
+    val ignoreLength = output.length()
 
-      // Feel free to change the starter sentence
-      var output = stcr.starterSentence
-      val randomSample = true
-      var newSentence = true
-      val ignoreLength = output.length()
+    for (i <- 0 until seqLength) {
+      if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
+      else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray)
+      val prob = model.forward(inputNdarray, newSentence)
+      newSentence = false
+      val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
+      if (nextChar == "") newSentence = true
+      if (i >= ignoreLength) output = output ++ nextChar
+    }
 
-      for (i <- 0 until seqLength) {
-        if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
-        else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray)
-        val prob = model.forward(inputNdarray, newSentence)
-        newSentence = false
-        val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
-        if (nextChar == "") newSentence = true
-        if (i >= ignoreLength) output = output ++ nextChar
-      }
+    // Let's see what we learned from the chars in Obama's speech.
+    logger.info(output)
+  }
 
-      // Let's see what we can learned from char in Obama's speech.
-      logger.info(output)
+  def main(args: Array[String]): Unit = {
+    val stcr = new TestCharRnn
+    val parser: CmdLineParser = new CmdLineParser(stcr)
+    try {
+      parser.parseArgument(args.toList.asJava)
+      assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null)
+      runTestCharRNN(stcr.dataPath, stcr.modelPrefix, stcr.starterSentence)
     } catch {
       case ex: Exception => {
         logger.error(ex.getMessage, ex)
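
Utils.makeOutput itself is not part of this diff, but its role in the loop is clear from the call site: map the softmax output to the next character, greedily or by sampling (randomSample = true here). A hypothetical sketch of that sampling step over a plain probability vector (the real helper's signature may differ):

```scala
import scala.util.Random

object SamplingSketch {
  // Draw the next char index from the softmax output: either greedily
  // (argmax) or by sampling the categorical distribution, which is what
  // randomSample = true selects in the loop above.
  def nextCharIdx(probs: Array[Float], sample: Boolean, rnd: Random): Int =
    if (!sample) probs.indices.maxBy(probs(_))
    else {
      val r = rnd.nextFloat()
      var acc = 0f
      var i = 0
      // advance until the cumulative probability first exceeds r
      while (i < probs.length - 1 && acc + probs(i) < r) { acc += probs(i); i += 1 }
      i
    }
}
```
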
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
index 3afb93686b0..fb59705c9ef 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
@@ -24,143 +24,144 @@ import scala.collection.JavaConverters._
 import org.apache.mxnet.optimizer.Adam
 
 /**
- * Follows the demo, to train the char rnn:
- * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
- */
+  * Follows the demo, to train the char rnn:
+  * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
+  */
 object TrainCharRnn {
 
   private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn])
 
-  def main(args: Array[String]): Unit = {
-    val incr = new TrainCharRnn
-    val parser: CmdLineParser = new CmdLineParser(incr)
-    try {
-      parser.parseArgument(args.toList.asJava)
-      assert(incr.dataPath != null && incr.saveModelPath != null)
-
-      // The batch size for training
-      val batchSize = 32
-      // We can support various length input
-      // For this problem, we cut each input sentence to length of 129
-      // So we only need fix length bucket
-      val buckets = Array(129)
-      // hidden unit in LSTM cell
-      val numHidden = 512
-      // embedding dimension, which is, map a char to a 256 dim vector
-      val numEmbed = 256
-      // number of lstm layer
-      val numLstmLayer = 3
-      // we will show a quick demo in 2 epoch
-      // and we will see result by training 75 epoch
-      val numEpoch = 75
-      // learning rate
-      val learningRate = 0.001f
-      // we will use pure sgd without momentum
-      val momentum = 0.0f
-
-      val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu)
-      val vocab = Utils.buildVocab(incr.dataPath)
-
-      // generate symbol for a length
-      def symGen(seqLen: Int): Symbol = {
-        Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
-                    numHidden = numHidden, numEmbed = numEmbed,
-                    numLabel = vocab.size + 1, dropout = 0.2f)
-      }
+  def runTrainCharRnn(dataPath: String, saveModelPath: String,
+                      ctx: Context, numEpoch: Int): Unit = {
+    // The batch size for training
+    val batchSize = 32
+    // We can support variable-length input.
+    // For this problem we cut each input sentence to a length of 129,
+    // so we only need a single fixed-length bucket.
+    val buckets = Array(129)
+    // hidden units per LSTM cell
+    val numHidden = 512
+    // embedding dimension: map each char to a 256-dim vector
+    val numEmbed = 256
+    // number of LSTM layers
+    val numLstmLayer = 3
+    // the number of epochs is supplied by the caller (75 in main, 1 in CI)
+    // learning rate
+    val learningRate = 0.001f
+    // momentum is kept for reference but unused; the optimizer below is Adam
+    val momentum = 0.0f
+
+    val vocab = Utils.buildVocab(dataPath)
+
+    // generate symbol for a length
+    def symGen(seqLen: Int): Symbol = {
+      Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
+        numHidden = numHidden, numEmbed = numEmbed,
+        numLabel = vocab.size + 1, dropout = 0.2f)
+    }
 
-      // initalize states for LSTM
-      val initC = for (l <- 0 until numLstmLayer)
-        yield (s"l${l}_init_c_beta", (batchSize, numHidden))
-      val initH = for (l <- 0 until numLstmLayer)
-        yield (s"l${l}_init_h_beta", (batchSize, numHidden))
-      val initStates = initC ++ initH
+    // initialize states for the LSTM
+    val initC = for (l <- 0 until numLstmLayer)
+      yield (s"l${l}_init_c_beta", (batchSize, numHidden))
+    val initH = for (l <- 0 until numLstmLayer)
+      yield (s"l${l}_init_h_beta", (batchSize, numHidden))
+    val initStates = initC ++ initH
 
-      val dataTrain = new BucketIo.BucketSentenceIter(incr.dataPath, vocab, buckets,
-                                          batchSize, initStates, seperateChar = "\n",
-                                          text2Id = Utils.text2Id, readContent = Utils.readContent)
+    val dataTrain = new BucketIo.BucketSentenceIter(dataPath, vocab, buckets,
+      batchSize, initStates, seperateChar = "\n",
+      text2Id = Utils.text2Id, readContent = Utils.readContent)
 
-      // the network symbol
-      val symbol = symGen(buckets(0))
+    // the network symbol
+    val symbol = symGen(buckets(0))
 
-      val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
-      val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
+    val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
+    val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
 
-      val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
+    val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
 
-      val argNames = symbol.listArguments()
-      val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
-      val auxNames = symbol.listAuxiliaryStates()
-      val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
+    val argNames = symbol.listArguments()
+    val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
+    val auxNames = symbol.listAuxiliaryStates()
+    val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
 
-      val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
-        !datasAndLabels.contains(name)
-      }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
+    val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
+      !datasAndLabels.contains(name)
+    }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
 
-      argDict.foreach { case (name, ndArray) =>
-        if (!datasAndLabels.contains(name)) {
-          initializer.initWeight(name, ndArray)
-        }
+    argDict.foreach { case (name, ndArray) =>
+      if (!datasAndLabels.contains(name)) {
+        initializer.initWeight(name, ndArray)
       }
+    }
 
-      val data = argDict("data")
-      val label = argDict("softmax_label")
+    val data = argDict("data")
+    val label = argDict("softmax_label")
 
-      val executor = symbol.bind(ctx, argDict, gradDict)
+    val executor = symbol.bind(ctx, argDict, gradDict)
 
-      val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
+    val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
 
-      val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
-        (idx, name, grad, opt.createState(idx, argDict(name)))
-      }
+    val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
+      (idx, name, grad, opt.createState(idx, argDict(name)))
+    }
 
-      val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
-      val batchEndCallback = new Callback.Speedometer(batchSize, 50)
-      val epochEndCallback = Utils.doCheckpoint(s"${incr.saveModelPath}/obama")
-
-      for (epoch <- 0 until numEpoch) {
-        // Training phase
-        val tic = System.currentTimeMillis
-        evalMetric.reset()
-        var nBatch = 0
-        var epochDone = false
-        // Iterate over training data.
-        dataTrain.reset()
-        while (!epochDone) {
-          var doReset = true
-          while (doReset && dataTrain.hasNext) {
-            val dataBatch = dataTrain.next()
-
-            data.set(dataBatch.data(0))
-            label.set(dataBatch.label(0))
-            executor.forward(isTrain = true)
-            executor.backward()
-            paramsGrads.foreach { case (idx, name, grad, optimState) =>
-              opt.update(idx, argDict(name), grad, optimState)
-            }
-
-            // evaluate at end, so out_cpu_array can lazy copy
-            evalMetric.update(dataBatch.label, executor.outputs)
-
-            nBatch += 1
-            batchEndCallback.invoke(epoch, nBatch, evalMetric)
+    val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
+    val batchEndCallback = new Callback.Speedometer(batchSize, 50)
+    val epochEndCallback = Utils.doCheckpoint(s"${saveModelPath}/obama")
+
+    for (epoch <- 0 until numEpoch) {
+      // Training phase
+      val tic = System.currentTimeMillis
+      evalMetric.reset()
+      var nBatch = 0
+      var epochDone = false
+      // Iterate over training data.
+      dataTrain.reset()
+      while (!epochDone) {
+        var doReset = true
+        while (doReset && dataTrain.hasNext) {
+          val dataBatch = dataTrain.next()
+
+          data.set(dataBatch.data(0))
+          label.set(dataBatch.label(0))
+          executor.forward(isTrain = true)
+          executor.backward()
+          paramsGrads.foreach { case (idx, name, grad, optimState) =>
+            opt.update(idx, argDict(name), grad, optimState)
           }
-          if (doReset) {
-            dataTrain.reset()
-          }
-          // this epoch is done
-          epochDone = true
+
+          // evaluate at end, so out_cpu_array can lazy copy
+          evalMetric.update(dataBatch.label, executor.outputs)
+
+          nBatch += 1
+          batchEndCallback.invoke(epoch, nBatch, evalMetric)
         }
-        val (name, value) = evalMetric.get
-        name.zip(value).foreach { case (n, v) =>
-          logger.info(s"Epoch[$epoch] Train-$n=$v")
+        if (doReset) {
+          dataTrain.reset()
         }
-        val toc = System.currentTimeMillis
-        logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
-
-        epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+        // this epoch is done
+        epochDone = true
       }
-      executor.dispose()
+      val (name, value) = evalMetric.get
+      name.zip(value).foreach { case (n, v) =>
+        logger.info(s"Epoch[$epoch] Train-$n=$v")
+      }
+      val toc = System.currentTimeMillis
+      logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
+
+      epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+    }
+    executor.dispose()
+  }
+
+  def main(args: Array[String]): Unit = {
+    val incr = new TrainCharRnn
+    val parser: CmdLineParser = new CmdLineParser(incr)
+    try {
+      parser.parseArgument(args.toList.asJava)
+      val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu)
+      assert(incr.dataPath != null && incr.saveModelPath != null)
+      runTrainCharRnn(incr.dataPath, incr.saveModelPath, ctx, 75)
     } catch {
       case ex: Exception => {
         logger.error(ex.getMessage, ex)
@@ -172,12 +173,6 @@ object TrainCharRnn {
 }
 
 class TrainCharRnn {
-  /*
-   * Get Training Data:  E.g.
-   * mkdir data; cd data
-   * wget "http://data.mxnet.io/mxnet/data/char_lstm.zip"
-   * unzip -o char_lstm.zip
-   */
   @Option(name = "--data-path", usage = "the input train data file")
   private val dataPath: String = "./data/obama.txt"
   @Option(name = "--save-model-path", usage = "the model saving path")
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
index c2902309679..3f9a9842e0a 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
@@ -25,9 +25,6 @@ import org.apache.mxnet.Model
 import org.apache.mxnet.Symbol
 import scala.util.Random
 
-/**
- * @author Depeng Liang
- */
 object Utils {
 
   def readContent(path: String): String = Source.fromFile(path).mkString
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
new file mode 100644
index 00000000000..b393a433305
--- /dev/null
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.rnn
+
+
+import org.apache.mxnet.Context
+import org.apache.mxnetexamples.Util
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
+import org.slf4j.LoggerFactory
+
+import scala.sys.process.Process
+
+@Ignore
+class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll {
+  private val logger = LoggerFactory.getLogger(classOf[ExampleRNNSuite])
+
+  override def beforeAll(): Unit = {
+    logger.info("Downloading LSTM model")
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    logger.info("tempDirPath: %s".format(tempDirPath))
+    val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/"
+    Util.downloadUrl(baseUrl + "obama.zip", tempDirPath + "/RNN/obama.zip")
+    Util.downloadUrl(baseUrl + "sherlockholmes.train.txt",
+      tempDirPath + "/RNN/sherlockholmes.train.txt")
+    Util.downloadUrl(baseUrl + "sherlockholmes.valid.txt",
+      tempDirPath + "/RNN/sherlockholmes.valid.txt")
+    // TODO: confirm this works on Windows (relies on the unzip CLI)
+    Process(s"unzip $tempDirPath/RNN/obama.zip -d $tempDirPath/RNN/") !
+  }
+
+  test("Example CI: Test LSTM Bucketing") {
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    var ctx = Context.cpu()
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      ctx = Context.gpu()
+    }
+    LstmBucketing.runTraining(tempDirPath + "/RNN/sherlockholmes.train.txt",
+        tempDirPath + "/RNN/sherlockholmes.valid.txt", Array(ctx), 1)
+  }
+
+  test("Example CI: Test TrainCharRNN") {
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+      System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+      val ctx = Context.gpu()
+      TrainCharRnn.runTrainCharRnn(tempDirPath + "/RNN/obama.txt",
+          tempDirPath, ctx, 1)
+    } else {
+      logger.info("CPU not supported for this test, skipped...")
+    }
+  }
+
+  test("Example CI: Test TestCharRNN") {
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    TestCharRnn.runTestCharRNN(tempDirPath + "/RNN/obama.txt",
+        tempDirPath + "/RNN/obama", "The joke")
+  }
+}
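
The suite is @Ignore'd by default, and the GPU-only tests gate on the SCALA_TEST_ON_GPU environment variable with the same inline check repeated per test. A small helper sketch for that recurring check (the helper itself is illustrative, not part of the PR):

```scala
import org.apache.mxnet.Context

object TestContext {
  // The SCALA_TEST_ON_GPU check duplicated across the tests above,
  // factored into one helper.
  def apply(): Context =
    if (sys.env.get("SCALA_TEST_ON_GPU").exists(_.toInt == 1)) Context.gpu()
    else Context.cpu()
}
```
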


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services