Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/08/21 18:09:42 UTC
[incubator-mxnet] branch master updated: [MXNET-836] RNN Example for Scala (#11753)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 38f80af [MXNET-836] RNN Example for Scala (#11753)
38f80af is described below
commit 38f80af6a4ac1ec1760772d5a407c39c876c16e6
Author: Lanking <la...@live.com>
AuthorDate: Tue Aug 21 11:09:23 2018 -0700
[MXNET-836] RNN Example for Scala (#11753)
* initial fix for RNN
* add CI test
* add encoding format
* scala style fix
* update readme
* test char RNN works
* ignore the test due to memory leaks
---
.../org/apache/mxnetexamples/rnn/BucketIo.scala | 19 +-
.../scala/org/apache/mxnetexamples/rnn/Lstm.scala | 97 ++++-----
.../apache/mxnetexamples/rnn/LstmBucketing.scala | 110 +++++-----
.../scala/org/apache/mxnetexamples/rnn/README.md | 48 +++++
.../org/apache/mxnetexamples/rnn/TestCharRnn.scala | 96 +++++----
.../apache/mxnetexamples/rnn/TrainCharRnn.scala | 237 ++++++++++-----------
.../scala/org/apache/mxnetexamples/rnn/Utils.scala | 3 -
.../apache/mxnetexamples/rnn/ExampleRNNSuite.scala | 75 +++++++
8 files changed, 399 insertions(+), 286 deletions(-)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index d4b1707..6d414bb 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -34,7 +34,7 @@ object BucketIo {
type ReadContent = String => String
def defaultReadContent(path: String): String = {
- Source.fromFile(path).mkString.replaceAll("\\. |\n", " <eos> ")
+ Source.fromFile(path, "UTF-8").mkString.replaceAll("\\. |\n", " <eos> ")
}
def defaultBuildVocab(path: String): Map[String, Int] = {
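Worth calling out: the small-looking change above is the commit's encoding fix. Without an explicit charset, Source.fromFile decodes with the JVM's platform default, so the same corpus can tokenize differently across machines. A minimal standalone sketch of the fixed read path (readCorpus is a hypothetical helper mirroring defaultReadContent):

```scala
import scala.io.Source

// Read the corpus as UTF-8 explicitly and mark sentence ends with an
// <eos> token. Closing the source avoids leaking the file handle.
def readCorpus(path: String): String = {
  val source = Source.fromFile(path, "UTF-8")
  try source.mkString.replaceAll("\\. |\n", " <eos> ")
  finally source.close()
}
```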
@@ -56,7 +56,7 @@ object BucketIo {
val tmp = sentence.split(" ").filter(_.length() > 0)
for (w <- tmp) yield theVocab(w)
}
- words.toArray
+ words
}
def defaultGenBuckets(sentences: Array[String], batchSize: Int,
@@ -162,8 +162,6 @@ object BucketIo {
labelBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket)))
}
- private val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
-
private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
}
@@ -208,12 +206,13 @@ object BucketIo {
tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
}
val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape)
- new DataBatch(IndexedSeq(dataBuf) ++ initStateArrays,
- IndexedSeq(labelBuf),
- getIndex(),
- getPad(),
- this.buckets(bucketIdx).asInstanceOf[AnyRef],
- batchProvideData, batchProvideLabel)
+ val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
+ new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
+ IndexedSeq(labelBuf.copy()),
+ getIndex(),
+ getPad(),
+ this.buckets(bucketIdx).asInstanceOf[AnyRef],
+ batchProvideData, batchProvideLabel)
}
/**
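The other substantive change in BucketIo is in the batch construction above: the returned DataBatch now wraps dataBuf.copy() and labelBuf.copy() instead of the buffers themselves, and the init-state arrays are freshly allocated per batch. Since the iterator refills its internal buffers on every call, handing out references would make earlier batches change underneath the consumer. A small sketch of the aliasing hazard (illustrative, not from the commit):

```scala
import org.apache.mxnet.NDArray

// A reused buffer handed out by reference changes retroactively;
// a copy() owns independent storage and stays stable.
val buffer = NDArray.zeros(2, 2)
val aliased = buffer        // same underlying storage
val owned = buffer.copy()   // independent storage
buffer.set(1f)              // the iterator refills its buffer
// aliased now reads all 1s; owned still reads all 0s
println(aliased.toArray.mkString(" "))
println(owned.toArray.mkString(" "))
```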
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
index bf29a47..872ef78 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Lstm.scala
@@ -18,13 +18,10 @@
package org.apache.mxnetexamples.rnn
-import org.apache.mxnet.Symbol
+import org.apache.mxnet.{Shape, Symbol}
import scala.collection.mutable.ArrayBuffer
-/**
- * @author Depeng Liang
- */
object Lstm {
final case class LSTMState(c: Symbol, h: Symbol)
@@ -35,27 +32,22 @@ object Lstm {
def lstm(numHidden: Int, inData: Symbol, prevState: LSTMState,
param: LSTMParam, seqIdx: Int, layerIdx: Int, dropout: Float = 0f): LSTMState = {
val inDataa = {
- if (dropout > 0f) Symbol.Dropout()()(Map("data" -> inData, "p" -> dropout))
+ if (dropout > 0f) Symbol.api.Dropout(data = Some(inData), p = Some(dropout))
else inData
}
- val i2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_i2h")()(Map("data" -> inDataa,
- "weight" -> param.i2hWeight,
- "bias" -> param.i2hBias,
- "num_hidden" -> numHidden * 4))
- val h2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_h2h")()(Map("data" -> prevState.h,
- "weight" -> param.h2hWeight,
- "bias" -> param.h2hBias,
- "num_hidden" -> numHidden * 4))
+ val i2h = Symbol.api.FullyConnected(data = Some(inDataa), weight = Some(param.i2hWeight),
+ bias = Some(param.i2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_i2h")
+ val h2h = Symbol.api.FullyConnected(data = Some(prevState.h), weight = Some(param.h2hWeight),
+ bias = Some(param.h2hBias), num_hidden = numHidden * 4, name = s"t${seqIdx}_l${layerIdx}_h2h")
val gates = i2h + h2h
- val sliceGates = Symbol.SliceChannel(s"t${seqIdx}_l${layerIdx}_slice")(
- gates)(Map("num_outputs" -> 4))
- val ingate = Symbol.Activation()()(Map("data" -> sliceGates.get(0), "act_type" -> "sigmoid"))
- val inTransform = Symbol.Activation()()(Map("data" -> sliceGates.get(1), "act_type" -> "tanh"))
- val forgetGate = Symbol.Activation()()(
- Map("data" -> sliceGates.get(2), "act_type" -> "sigmoid"))
- val outGate = Symbol.Activation()()(Map("data" -> sliceGates.get(3), "act_type" -> "sigmoid"))
+ val sliceGates = Symbol.api.SliceChannel(data = Some(gates), num_outputs = 4,
+ name = s"t${seqIdx}_l${layerIdx}_slice")
+ val ingate = Symbol.api.Activation(data = Some(sliceGates.get(0)), act_type = "sigmoid")
+ val inTransform = Symbol.api.Activation(data = Some(sliceGates.get(1)), act_type = "tanh")
+ val forgetGate = Symbol.api.Activation(data = Some(sliceGates.get(2)), act_type = "sigmoid")
+ val outGate = Symbol.api.Activation(data = Some(sliceGates.get(3)), act_type = "sigmoid")
val nextC = (forgetGate * prevState.c) + (ingate * inTransform)
- val nextH = outGate * Symbol.Activation()()(Map("data" -> nextC, "act_type" -> "tanh"))
+ val nextH = outGate * Symbol.api.Activation(data = Some(nextC), act_type = "tanh")
LSTMState(c = nextC, h = nextH)
}
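For readers tracing the rewrite, the cell above is the standard LSTM update. The fused gates symbol is sliced into four blocks, which map in slice order 0..3 to the input gate, candidate, forget gate, and output gate:

```latex
g_t = W_{i2h} x_t + b_{i2h} + W_{h2h} h_{t-1} + b_{h2h} \in \mathbb{R}^{4H}
i_t = \sigma(g_t^{(0)}), \quad \tilde{c}_t = \tanh(g_t^{(1)}), \quad
f_t = \sigma(g_t^{(2)}), \quad o_t = \sigma(g_t^{(3)})
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \qquad
h_t = o_t \odot \tanh(c_t)
```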
@@ -74,11 +66,11 @@ object Lstm {
val lastStatesBuf = ArrayBuffer[LSTMState]()
for (i <- 0 until numLstmLayer) {
paramCellsBuf.append(LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
- i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
- h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
- h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
+ i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
+ h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
+ h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
- h = Symbol.Variable(s"l${i}_init_h_beta")))
+ h = Symbol.Variable(s"l${i}_init_h_beta")))
}
val paramCells = paramCellsBuf.toArray
val lastStates = lastStatesBuf.toArray
@@ -87,10 +79,10 @@ object Lstm {
// embedding layer
val data = Symbol.Variable("data")
var label = Symbol.Variable("softmax_label")
- val embed = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
- "weight" -> embedWeight, "output_dim" -> numEmbed))
- val wordvec = Symbol.SliceChannel()()(
- Map("data" -> embed, "num_outputs" -> seqLen, "squeeze_axis" -> 1))
+ val embed = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
+ weight = Some(embedWeight), output_dim = numEmbed, name = "embed")
+ val wordvec = Symbol.api.SliceChannel(data = Some(embed),
+ num_outputs = seqLen, squeeze_axis = Some(true))
val hiddenAll = ArrayBuffer[Symbol]()
var dpRatio = 0f
@@ -101,22 +93,23 @@ object Lstm {
for (i <- 0 until numLstmLayer) {
if (i == 0) dpRatio = 0f else dpRatio = dropout
val nextState = lstm(numHidden, inData = hidden,
- prevState = lastStates(i),
- param = paramCells(i),
- seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
+ prevState = lastStates(i),
+ param = paramCells(i),
+ seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
hidden = nextState.h
lastStates(i) = nextState
}
// decoder
- if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
+ if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
hiddenAll.append(hidden)
}
- val hiddenConcat = Symbol.Concat()(hiddenAll: _*)(Map("dim" -> 0))
- val pred = Symbol.FullyConnected("pred")()(Map("data" -> hiddenConcat, "num_hidden" -> numLabel,
- "weight" -> clsWeight, "bias" -> clsBias))
- label = Symbol.transpose()(label)()
- label = Symbol.Reshape()()(Map("data" -> label, "target_shape" -> "(0,)"))
- val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> pred, "label" -> label))
+ val hiddenConcat = Symbol.api.Concat(data = hiddenAll.toArray, num_args = hiddenAll.length,
+ dim = Some(0))
+ val pred = Symbol.api.FullyConnected(data = Some(hiddenConcat), num_hidden = numLabel,
+ weight = Some(clsWeight), bias = Some(clsBias))
+ label = Symbol.api.transpose(data = Some(label))
+ label = Symbol.api.Reshape(data = Some(label), target_shape = Some(Shape(0)))
+ val sm = Symbol.api.SoftmaxOutput(data = Some(pred), label = Some(label), name = "softmax")
sm
}
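The pattern throughout this file is the same: every stringly-typed Symbol.Op(name)()(Map(...)) call becomes a named-argument Symbol.api.Op(...) call. A minimal side-by-side sketch (hypothetical fc layer) of why the new form is safer:

```scala
import org.apache.mxnet.Symbol

val data = Symbol.Variable("data")

// Old style: keys and values go through a Map[String, Any], so a typo
// like "num_hiden" or a wrongly typed value only fails at runtime:
// val fcOld = Symbol.FullyConnected("fc")()(Map("data" -> data, "num_hidden" -> 10))

// New type-safe style: parameter names and types are checked at compile time.
val fc = Symbol.api.FullyConnected(data = Some(data), num_hidden = 10, name = "fc")
```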
@@ -131,35 +124,35 @@ object Lstm {
var lastStates = Array[LSTMState]()
for (i <- 0 until numLstmLayer) {
paramCells = paramCells :+ LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
- i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
- h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
- h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
+ i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
+ h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
+ h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
- h = Symbol.Variable(s"l${i}_init_h_beta"))
+ h = Symbol.Variable(s"l${i}_init_h_beta"))
}
assert(lastStates.length == numLstmLayer)
val data = Symbol.Variable("data")
- var hidden = Symbol.Embedding("embed")()(Map("data" -> data, "input_dim" -> inputSize,
- "weight" -> embedWeight, "output_dim" -> numEmbed))
+ var hidden = Symbol.api.Embedding(data = Some(data), input_dim = inputSize,
+ weight = Some(embedWeight), output_dim = numEmbed, name = "embed")
var dpRatio = 0f
// stack LSTM
for (i <- 0 until numLstmLayer) {
if (i == 0) dpRatio = 0f else dpRatio = dropout
val nextState = lstm(numHidden, inData = hidden,
- prevState = lastStates(i),
- param = paramCells(i),
- seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
+ prevState = lastStates(i),
+ param = paramCells(i),
+ seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
hidden = nextState.h
lastStates(i) = nextState
}
// decoder
- if (dropout > 0f) hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
- val fc = Symbol.FullyConnected("pred")()(Map("data" -> hidden, "num_hidden" -> numLabel,
- "weight" -> clsWeight, "bias" -> clsBias))
- val sm = Symbol.SoftmaxOutput("softmax")()(Map("data" -> fc))
+ if (dropout > 0f) hidden = Symbol.api.Dropout(data = Some(hidden), p = Some(dropout))
+ val fc = Symbol.api.FullyConnected(data = Some(hidden),
+ num_hidden = numLabel, weight = Some(clsWeight), bias = Some(clsBias))
+ val sm = Symbol.api.SoftmaxOutput(data = Some(fc), name = "softmax")
var output = Array(sm)
for (state <- lastStates) {
output = output :+ state.c
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
index 44ee6e7..f7a01ba 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
@@ -30,9 +30,8 @@ import org.apache.mxnet.module.BucketingModule
import org.apache.mxnet.module.FitParams
/**
- * Bucketing LSTM examples
- * @author Yizhi Liu
- */
+ * Bucketing LSTM examples
+ */
class LstmBucketing {
@Option(name = "--data-train", usage = "training set")
private val dataTrain: String = "example/rnn/sherlockholmes.train.txt"
@@ -61,6 +60,60 @@ object LstmBucketing {
Math.exp(loss / labelArr.length).toFloat
}
+ def runTraining(trainData : String, validationData : String,
+ ctx : Array[Context], numEpoch : Int): Unit = {
+ val batchSize = 32
+ val buckets = Array(10, 20, 30, 40, 50, 60)
+ val numHidden = 200
+ val numEmbed = 200
+ val numLstmLayer = 2
+
+ logger.info("Building vocab ...")
+ val vocab = BucketIo.defaultBuildVocab(trainData)
+
+ def BucketSymGen(key: AnyRef):
+ (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
+ val seqLen = key.asInstanceOf[Int]
+ val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
+ numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
+ (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
+ }
+
+ val initC = (0 until numLstmLayer).map(l =>
+ (s"l${l}_init_c_beta", (batchSize, numHidden))
+ )
+ val initH = (0 until numLstmLayer).map(l =>
+ (s"l${l}_init_h_beta", (batchSize, numHidden))
+ )
+ val initStates = initC ++ initH
+
+ val dataTrain = new BucketSentenceIter(trainData, vocab,
+ buckets, batchSize, initStates)
+ val dataVal = new BucketSentenceIter(validationData, vocab,
+ buckets, batchSize, initStates)
+
+ val model = new BucketingModule(
+ symGen = BucketSymGen,
+ defaultBucketKey = dataTrain.defaultBucketKey,
+ contexts = ctx)
+
+ val fitParams = new FitParams()
+ fitParams.setEvalMetric(
+ new CustomMetric(perplexity, name = "perplexity"))
+ fitParams.setKVStore("device")
+ fitParams.setOptimizer(
+ new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
+ fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
+ fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
+
+ logger.info("Start training ...")
+ model.fit(
+ trainData = dataTrain,
+ evalData = Some(dataVal),
+ numEpoch = numEpoch, fitParams)
+ logger.info("Finished training...")
+ }
+
def main(args: Array[String]): Unit = {
val inst = new LstmBucketing
val parser: CmdLineParser = new CmdLineParser(inst)
@@ -71,56 +124,7 @@ object LstmBucketing {
else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
else Array(Context.cpu(0))
- val batchSize = 32
- val buckets = Array(10, 20, 30, 40, 50, 60)
- val numHidden = 200
- val numEmbed = 200
- val numLstmLayer = 2
-
- logger.info("Building vocab ...")
- val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)
-
- def BucketSymGen(key: AnyRef):
- (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
- val seqLen = key.asInstanceOf[Int]
- val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
- numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
- (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
- }
-
- val initC = (0 until numLstmLayer).map(l =>
- (s"l${l}_init_c_beta", (batchSize, numHidden))
- )
- val initH = (0 until numLstmLayer).map(l =>
- (s"l${l}_init_h_beta", (batchSize, numHidden))
- )
- val initStates = initC ++ initH
-
- val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab,
- buckets, batchSize, initStates)
- val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
- buckets, batchSize, initStates)
-
- val model = new BucketingModule(
- symGen = BucketSymGen,
- defaultBucketKey = dataTrain.defaultBucketKey,
- contexts = contexts)
-
- val fitParams = new FitParams()
- fitParams.setEvalMetric(
- new CustomMetric(perplexity, name = "perplexity"))
- fitParams.setKVStore("device")
- fitParams.setOptimizer(
- new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
- fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
- fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
-
- logger.info("Start training ...")
- model.fit(
- trainData = dataTrain,
- evalData = Some(dataVal),
- numEpoch = inst.numEpoch, fitParams)
- logger.info("Finished training...")
+ runTraining(inst.dataTrain, inst.dataVal, contexts, 5)
} catch {
case ex: Exception =>
logger.error(ex.getMessage, ex)
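A note on the metric wired in through CustomMetric(perplexity, ...) above: perplexity is the exponential of the average per-token cross-entropy, exactly what Math.exp(loss / labelArr.length) computes. Lower is better, and a model that guesses uniformly over a vocabulary of size V scores V:

```latex
\mathrm{PPL} = \exp\Big(-\frac{1}{N}\sum_{i=1}^{N} \log p\big(y_i \mid y_{<i}\big)\Big)
```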
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md
new file mode 100644
index 0000000..5289fc7
--- /dev/null
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/README.md
@@ -0,0 +1,48 @@
+# RNN Example for MXNet Scala
+This folder contains the following examples, written in the new Scala type-safe API:
+- [x] LSTM Bucketing
+- [x] CharRNN Inference: generate similar text based on a trained model
+- [x] CharRNN Training: train a character-level language model with an RNN
+
+These examples are for illustration only and are not tuned for the best accuracy.
+
+## Setup
+### Download the Network Definition, Weights and Training Data
+`obama.zip` contains the training inputs (Obama's speeches) for the CharRNN examples, and the `sherlockholmes` files contain the data for LSTM Bucketing
+```bash
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/obama.zip
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.train.txt
+wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/sherlockholmes.valid.txt
+```
+### Unzip the file
+```bash
+unzip obama.zip
+```
+### Argument Configuration
+Then define the arguments you would like to pass to the model:
+
+#### LSTM Bucketing
+```bash
+--data-train
+<path>/sherlockholmes.train.txt
+--data-val
+<path>/sherlockholmes.valid.txt
+--cpus
+<num_cpus>
+--gpus
+<num_gpu>
+```
+#### TrainCharRnn
+```bash
+--data-path
+<path>/obama.txt
+--save-model-path
+<path>/
+```
+#### TestCharRnn
+```bash
+--data-path
+<path>/obama.txt
+--model-prefix
+<path>/obama
+```
\ No newline at end of file
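Besides the CLI flags above, the entry points this commit factors out (runTraining, runTrainCharRnn, runTestCharRNN) can be called directly from Scala. A minimal sketch for LSTM Bucketing, with placeholder paths:

```scala
import org.apache.mxnet.Context
import org.apache.mxnetexamples.rnn.LstmBucketing

// Placeholder paths; point these at the files downloaded in Setup.
LstmBucketing.runTraining(
  "/tmp/RNN/sherlockholmes.train.txt",
  "/tmp/RNN/sherlockholmes.valid.txt",
  Array(Context.cpu(0)),
  numEpoch = 1)
```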
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
index 243b70c..4786d5d 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -25,66 +25,68 @@ import scala.collection.JavaConverters._
/**
* Follows the demo to test the char RNN:
* https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
*/
object TestCharRnn {
private val logger = LoggerFactory.getLogger(classOf[TestCharRnn])
- def main(args: Array[String]): Unit = {
- val stcr = new TestCharRnn
- val parser: CmdLineParser = new CmdLineParser(stcr)
- try {
- parser.parseArgument(args.toList.asJava)
- assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null)
+ def runTestCharRNN(dataPath: String, modelPrefix: String, starterSentence : String): Unit = {
+ // The batch size for training
+ val batchSize = 32
+ // We can support variable-length input;
+ // for this problem we cut each input sentence to length 129,
+ // so we only need a fixed-length bucket
+ val buckets = List(129)
+ // hidden units in the LSTM cell
+ val numHidden = 512
+ // embedding dimension, i.e. map a char to a 256-dim vector
+ val numEmbed = 256
+ // number of LSTM layers
+ val numLstmLayer = 3
- // The batch size for training
- val batchSize = 32
- // We can support various length input
- // For this problem, we cut each input sentence to length of 129
- // So we only need fix length bucket
- val buckets = List(129)
- // hidden unit in LSTM cell
- val numHidden = 512
- // embedding dimension, which is, map a char to a 256 dim vector
- val numEmbed = 256
- // number of lstm layer
- val numLstmLayer = 3
+ // build char vocabulary from input
+ val vocab = Utils.buildVocab(dataPath)
- // build char vocabluary from input
- val vocab = Utils.buildVocab(stcr.dataPath)
+ // load from check-point
+ val (_, argParams, _) = Model.loadCheckpoint(modelPrefix, 75)
- // load from check-point
- val (_, argParams, _) = Model.loadCheckpoint(stcr.modelPrefix, 75)
+ // build an inference model
+ val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
+ numHidden = numHidden, numEmbed = numEmbed,
+ numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
- // build an inference model
- val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
- numHidden = numHidden, numEmbed = numEmbed,
- numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
+ // generate a sequence of 1200 chars
+ val seqLength = 1200
+ val inputNdarray = NDArray.zeros(1)
+ val revertVocab = Utils.makeRevertVocab(vocab)
- // generate a sequence of 1200 chars
- val seqLength = 1200
- val inputNdarray = NDArray.zeros(1)
- val revertVocab = Utils.makeRevertVocab(vocab)
+ // Feel free to change the starter sentence
+ var output = starterSentence
+ val randomSample = true
+ var newSentence = true
+ val ignoreLength = output.length()
- // Feel free to change the starter sentence
- var output = stcr.starterSentence
- val randomSample = true
- var newSentence = true
- val ignoreLength = output.length()
+ for (i <- 0 until seqLength) {
+ if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
+ else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray)
+ val prob = model.forward(inputNdarray, newSentence)
+ newSentence = false
+ val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
+ if (nextChar == "") newSentence = true
+ if (i >= ignoreLength) output = output ++ nextChar
+ }
- for (i <- 0 until seqLength) {
- if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
- else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray)
- val prob = model.forward(inputNdarray, newSentence)
- newSentence = false
- val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
- if (nextChar == "") newSentence = true
- if (i >= ignoreLength) output = output ++ nextChar
- }
+ // Let's see what the char model learned from Obama's speeches.
+ logger.info(output)
+ }
- // Let's see what we can learned from char in Obama's speech.
- logger.info(output)
+ def main(args: Array[String]): Unit = {
+ val stcr = new TestCharRnn
+ val parser: CmdLineParser = new CmdLineParser(stcr)
+ try {
+ parser.parseArgument(args.toList.asJava)
+ assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null)
+ runTestCharRNN(stcr.dataPath, stcr.modelPrefix, stcr.starterSentence)
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
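One detail of the generation loop above: with randomSample = true, Utils.makeOutput draws the next character from the softmax distribution rather than taking the argmax, which keeps the generated text from looping. A sketch of that kind of weighted draw (assumed behavior, not makeOutput's literal implementation):

```scala
import scala.util.Random

// Draw an index with probability proportional to probs(i).
// Assumes probs sums to ~1, as a softmax output does.
def sampleIndex(probs: Array[Float], rnd: Random = new Random()): Int = {
  val r = rnd.nextFloat()
  var cum = 0f
  var i = 0
  while (i < probs.length - 1 && { cum += probs(i); cum < r }) i += 1
  i
}
```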
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
index 3afb936..fb59705 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
@@ -24,143 +24,144 @@ import scala.collection.JavaConverters._
import org.apache.mxnet.optimizer.Adam
/**
- * Follows the demo, to train the char rnn:
- * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
- * @author Depeng Liang
- */
+ * Follows the demo to train the char RNN:
+ * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
+ */
object TrainCharRnn {
private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn])
- def main(args: Array[String]): Unit = {
- val incr = new TrainCharRnn
- val parser: CmdLineParser = new CmdLineParser(incr)
- try {
- parser.parseArgument(args.toList.asJava)
- assert(incr.dataPath != null && incr.saveModelPath != null)
-
- // The batch size for training
- val batchSize = 32
- // We can support various length input
- // For this problem, we cut each input sentence to length of 129
- // So we only need fix length bucket
- val buckets = Array(129)
- // hidden unit in LSTM cell
- val numHidden = 512
- // embedding dimension, which is, map a char to a 256 dim vector
- val numEmbed = 256
- // number of lstm layer
- val numLstmLayer = 3
- // we will show a quick demo in 2 epoch
- // and we will see result by training 75 epoch
- val numEpoch = 75
- // learning rate
- val learningRate = 0.001f
- // we will use pure sgd without momentum
- val momentum = 0.0f
-
- val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu)
- val vocab = Utils.buildVocab(incr.dataPath)
-
- // generate symbol for a length
- def symGen(seqLen: Int): Symbol = {
- Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
- numHidden = numHidden, numEmbed = numEmbed,
- numLabel = vocab.size + 1, dropout = 0.2f)
- }
+ def runTrainCharRnn(dataPath: String, saveModelPath: String,
+ ctx : Context, numEpoch : Int): Unit = {
+ // The batch size for training
+ val batchSize = 32
+ // We can support variable-length input;
+ // for this problem we cut each input sentence to length 129,
+ // so we only need a fixed-length bucket
+ val buckets = Array(129)
+ // hidden units in the LSTM cell
+ val numHidden = 512
+ // embedding dimension, i.e. map a char to a 256-dim vector
+ val numEmbed = 256
+ // number of LSTM layers
+ val numLstmLayer = 3
+ // learning rate
+ val learningRate = 0.001f
+ // kept for reference; the optimizer below is Adam, not plain SGD
+ val momentum = 0.0f
+
+ val vocab = Utils.buildVocab(dataPath)
+
+ // generate symbol for a length
+ def symGen(seqLen: Int): Symbol = {
+ Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
+ numHidden = numHidden, numEmbed = numEmbed,
+ numLabel = vocab.size + 1, dropout = 0.2f)
+ }
- // initalize states for LSTM
- val initC = for (l <- 0 until numLstmLayer)
- yield (s"l${l}_init_c_beta", (batchSize, numHidden))
- val initH = for (l <- 0 until numLstmLayer)
- yield (s"l${l}_init_h_beta", (batchSize, numHidden))
- val initStates = initC ++ initH
+ // initialize states for LSTM
+ val initC = for (l <- 0 until numLstmLayer)
+ yield (s"l${l}_init_c_beta", (batchSize, numHidden))
+ val initH = for (l <- 0 until numLstmLayer)
+ yield (s"l${l}_init_h_beta", (batchSize, numHidden))
+ val initStates = initC ++ initH
- val dataTrain = new BucketIo.BucketSentenceIter(incr.dataPath, vocab, buckets,
- batchSize, initStates, seperateChar = "\n",
- text2Id = Utils.text2Id, readContent = Utils.readContent)
+ val dataTrain = new BucketIo.BucketSentenceIter(dataPath, vocab, buckets,
+ batchSize, initStates, seperateChar = "\n",
+ text2Id = Utils.text2Id, readContent = Utils.readContent)
- // the network symbol
- val symbol = symGen(buckets(0))
+ // the network symbol
+ val symbol = symGen(buckets(0))
- val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
- val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
+ val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
+ val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
- val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
+ val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
- val argNames = symbol.listArguments()
- val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
- val auxNames = symbol.listAuxiliaryStates()
- val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
+ val argNames = symbol.listArguments()
+ val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
+ val auxNames = symbol.listAuxiliaryStates()
+ val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
- val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
- !datasAndLabels.contains(name)
- }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
+ val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
+ !datasAndLabels.contains(name)
+ }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
- argDict.foreach { case (name, ndArray) =>
- if (!datasAndLabels.contains(name)) {
- initializer.initWeight(name, ndArray)
- }
+ argDict.foreach { case (name, ndArray) =>
+ if (!datasAndLabels.contains(name)) {
+ initializer.initWeight(name, ndArray)
}
+ }
- val data = argDict("data")
- val label = argDict("softmax_label")
+ val data = argDict("data")
+ val label = argDict("softmax_label")
- val executor = symbol.bind(ctx, argDict, gradDict)
+ val executor = symbol.bind(ctx, argDict, gradDict)
- val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
+ val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
- val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
- (idx, name, grad, opt.createState(idx, argDict(name)))
- }
+ val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
+ (idx, name, grad, opt.createState(idx, argDict(name)))
+ }
- val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
- val batchEndCallback = new Callback.Speedometer(batchSize, 50)
- val epochEndCallback = Utils.doCheckpoint(s"${incr.saveModelPath}/obama")
-
- for (epoch <- 0 until numEpoch) {
- // Training phase
- val tic = System.currentTimeMillis
- evalMetric.reset()
- var nBatch = 0
- var epochDone = false
- // Iterate over training data.
- dataTrain.reset()
- while (!epochDone) {
- var doReset = true
- while (doReset && dataTrain.hasNext) {
- val dataBatch = dataTrain.next()
-
- data.set(dataBatch.data(0))
- label.set(dataBatch.label(0))
- executor.forward(isTrain = true)
- executor.backward()
- paramsGrads.foreach { case (idx, name, grad, optimState) =>
- opt.update(idx, argDict(name), grad, optimState)
- }
-
- // evaluate at end, so out_cpu_array can lazy copy
- evalMetric.update(dataBatch.label, executor.outputs)
-
- nBatch += 1
- batchEndCallback.invoke(epoch, nBatch, evalMetric)
+ val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
+ val batchEndCallback = new Callback.Speedometer(batchSize, 50)
+ val epochEndCallback = Utils.doCheckpoint(s"${saveModelPath}/obama")
+
+ for (epoch <- 0 until numEpoch) {
+ // Training phase
+ val tic = System.currentTimeMillis
+ evalMetric.reset()
+ var nBatch = 0
+ var epochDone = false
+ // Iterate over training data.
+ dataTrain.reset()
+ while (!epochDone) {
+ var doReset = true
+ while (doReset && dataTrain.hasNext) {
+ val dataBatch = dataTrain.next()
+
+ data.set(dataBatch.data(0))
+ label.set(dataBatch.label(0))
+ executor.forward(isTrain = true)
+ executor.backward()
+ paramsGrads.foreach { case (idx, name, grad, optimState) =>
+ opt.update(idx, argDict(name), grad, optimState)
}
- if (doReset) {
- dataTrain.reset()
- }
- // this epoch is done
- epochDone = true
+
+ // evaluate at the end, so out_cpu_array can be copied lazily
+ evalMetric.update(dataBatch.label, executor.outputs)
+
+ nBatch += 1
+ batchEndCallback.invoke(epoch, nBatch, evalMetric)
}
- val (name, value) = evalMetric.get
- name.zip(value).foreach { case (n, v) =>
- logger.info(s"Epoch[$epoch] Train-$n=$v")
+ if (doReset) {
+ dataTrain.reset()
}
- val toc = System.currentTimeMillis
- logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
-
- epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+ // this epoch is done
+ epochDone = true
}
- executor.dispose()
+ val (name, value) = evalMetric.get
+ name.zip(value).foreach { case (n, v) =>
+ logger.info(s"Epoch[$epoch] Train-$n=$v")
+ }
+ val toc = System.currentTimeMillis
+ logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
+
+ epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+ }
+ executor.dispose()
+ }
+
+ def main(args: Array[String]): Unit = {
+ val incr = new TrainCharRnn
+ val parser: CmdLineParser = new CmdLineParser(incr)
+ try {
+ parser.parseArgument(args.toList.asJava)
+ val ctx = if (incr.gpu == -1) Context.cpu() else Context.gpu(incr.gpu)
+ assert(incr.dataPath != null && incr.saveModelPath != null)
+ runTrainCharRnn(incr.dataPath, incr.saveModelPath, ctx, 75)
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
@@ -172,12 +173,6 @@ object TrainCharRnn {
}
class TrainCharRnn {
- /*
- * Get Training Data: E.g.
- * mkdir data; cd data
- * wget "http://data.mxnet.io/mxnet/data/char_lstm.zip"
- * unzip -o char_lstm.zip
- */
@Option(name = "--data-path", usage = "the input train data file")
private val dataPath: String = "./data/obama.txt"
@Option(name = "--save-model-path", usage = "the model saving path")
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
index c290230..3f9a984 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/Utils.scala
@@ -25,9 +25,6 @@ import org.apache.mxnet.Model
import org.apache.mxnet.Symbol
import scala.util.Random
-/**
- * @author Depeng Liang
- */
object Utils {
def readContent(path: String): String = Source.fromFile(path).mkString
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
new file mode 100644
index 0000000..b393a43
--- /dev/null
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.rnn
+
+
+import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnetexamples.Util
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
+import org.slf4j.LoggerFactory
+
+import scala.sys.process.Process
+
+@Ignore
+class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll {
+ private val logger = LoggerFactory.getLogger(classOf[ExampleRNNSuite])
+
+ override def beforeAll(): Unit = {
+ logger.info("Downloading LSTM model")
+ val tempDirPath = System.getProperty("java.io.tmpdir")
+ logger.info("tempDirPath: %s".format(tempDirPath))
+ val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/RNN/"
+ Util.downloadUrl(baseUrl + "obama.zip", tempDirPath + "/RNN/obama.zip")
+ Util.downloadUrl(baseUrl + "sherlockholmes.train.txt",
+ tempDirPath + "/RNN/sherlockholmes.train.txt")
+ Util.downloadUrl(baseUrl + "sherlockholmes.valid.txt",
+ tempDirPath + "/RNN/sherlockholmes.valid.txt")
+ // TODO: confirm this unzip step works on Windows
+ Process(s"unzip $tempDirPath/RNN/obama.zip -d $tempDirPath/RNN/") !
+ }
+
+ test("Example CI: Test LSTM Bucketing") {
+ val tempDirPath = System.getProperty("java.io.tmpdir")
+ var ctx = Context.cpu()
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+ ctx = Context.gpu()
+ }
+ LstmBucketing.runTraining(tempDirPath + "/RNN/sherlockholmes.train.txt",
+ tempDirPath + "/RNN/sherlockholmes.valid.txt", Array(ctx), 1)
+ }
+
+ test("Example CI: Test TrainCharRNN") {
+ val tempDirPath = System.getProperty("java.io.tmpdir")
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+ val ctx = Context.gpu()
+ TrainCharRnn.runTrainCharRnn(tempDirPath + "/RNN/obama.txt",
+ tempDirPath, ctx, 1)
+ } else {
+ logger.info("CPU not supported for this test, skipped...")
+ }
+ }
+
+ test("Example CI: Test TestCharRNN") {
+ val tempDirPath = System.getProperty("java.io.tmpdir")
+ val ctx = Context.gpu()
+ TestCharRnn.runTestCharRNN(tempDirPath + "/RNN/obama.txt",
+ tempDirPath + "/RNN/obama", "The joke")
+ }
+}
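The suite is @Ignore'd because of the memory leaks called out in the commit message, which also explains the otherwise unused NDArrayCollector import. A sketch (assumed usage, not part of this commit) of how a collector scope can bound that growth by disposing every NDArray created inside it:

```scala
import org.apache.mxnet.{Context, NDArrayCollector}

// Placeholder paths; mirror the files downloaded in beforeAll().
val tmp = System.getProperty("java.io.tmpdir")
NDArrayCollector.auto().withScope {
  LstmBucketing.runTraining(tmp + "/RNN/sherlockholmes.train.txt",
    tmp + "/RNN/sherlockholmes.valid.txt", Array(Context.cpu(0)), 1)
}
```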