You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/08/23 01:04:35 UTC

[incubator-mxnet] branch master updated: [MXNET-729] Use NDArrayCollector to fix memory leaks in Scala Examples (#12232)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f177d8  [MXNET-729] Use NDArrayCollector to fix memory leaks in Scala Examples (#12232)
2f177d8 is described below

commit 2f177d8a318fc9c0ad1b80a77ca82eeb4ab9f28e
Author: Lanking <la...@live.com>
AuthorDate: Wed Aug 22 18:04:25 2018 -0700

    [MXNET-729] Use NDArrayCollector to fix memory leaks in Scala Examples (#12232)
    
    * initial fix for RNN
    
    * add CI test
    
    * ignore the test due to memory leaks
    
    * release the GAN beast
    
    * enable rnn
    
    * add collector and dispose
    
    * revert the hacky thing after rebase
    
    * rename with inference
    
    * add collector in some examples
    
    * add experimental tag and comments
    
    * change the scope of the NDArrayCollector
    
    * apply final changes...
    
    * fix scalastyle
---
 .../scala/org/apache/mxnet/NDArrayCollector.scala  |   3 +
 .../org/apache/mxnet/annotation/Experimental.scala |   4 +
 .../CNNTextClassification.scala                    | 143 +++++++-------
 .../mxnetexamples/customop/ExampleCustomOp.scala   |  83 ++++----
 .../org/apache/mxnetexamples/gan/GanMnist.scala    | 137 ++++++-------
 .../imclassification/TrainMnist.scala              |  22 ++-
 .../imageclassifier/ImageClassifierExample.scala   |  60 +++---
 .../objectdetector/SSDClassifierExample.scala      |  56 +++---
 .../mxnetexamples/multitask/ExampleMultiTask.scala | 201 +++++++++----------
 .../mxnetexamples/neuralstyle/NeuralStyle.scala    | 175 ++++++++---------
 .../neuralstyle/end2end/BoostInference.scala       |  48 ++---
 .../neuralstyle/end2end/BoostTrain.scala           | 214 +++++++++++----------
 .../apache/mxnetexamples/rnn/LstmBucketing.scala   | 100 +++++-----
 .../org/apache/mxnetexamples/rnn/TestCharRnn.scala |  86 +++++----
 .../apache/mxnetexamples/rnn/TrainCharRnn.scala    | 202 +++++++++----------
 .../CNNClassifierExampleSuite.scala                |   2 +-
 .../apache/mxnetexamples/gan/GanExampleSuite.scala |   7 +-
 .../ImageClassifierExampleSuite.scala              |   6 +-
 .../ObjectDetectorExampleSuite.scala               |   8 +-
 .../mxnetexamples/multitask/MultiTaskSuite.scala   |  19 +-
 .../neuralstyle/NeuralStyleSuite.scala             |   2 +-
 .../apache/mxnetexamples/rnn/ExampleRNNSuite.scala |  11 +-
 22 files changed, 808 insertions(+), 781 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
index ea21cff..3952b73 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
@@ -18,6 +18,7 @@
 package org.apache.mxnet
 
 import org.apache.mxnet.Base.CPtrAddress
+import org.apache.mxnet.annotation.Experimental
 import org.slf4j.LoggerFactory
 
 import scala.annotation.varargs
@@ -80,6 +81,7 @@ object NDArrayCollector {
    * Create a collector allows users to later dispose the collected NDArray manually.
    * @return a manually-disposable collector.
    */
+  @Experimental
   def manual(): NDArrayCollector = new NDArrayCollector(false)
 
   /**
@@ -135,6 +137,7 @@ class NDArrayCollector private(private val autoDispose: Boolean = true,
    * @tparam T return type of the function <em>codeBlock</em>.
    * @return The result of function <em>codeBlock</em>.
    */
+  @Experimental
   def withScope[T](codeBlock: => T): T = {
     val old = NDArrayCollector.currCollector.get()
     NDArrayCollector.currCollector.set(this)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/annotation/Experimental.scala b/scala-package/core/src/main/scala/org/apache/mxnet/annotation/Experimental.scala
index 33d1d33..147d651 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/annotation/Experimental.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/annotation/Experimental.scala
@@ -19,6 +19,10 @@ package org.apache.mxnet.annotation
 
 import java.lang.annotation.{ElementType, Retention, Target, _}
 
+/**
+  * Experimental: there is a comparatively high chance that
+  * the API will change in future releases.
+  */
 @Retention(RetentionPolicy.RUNTIME)
 @Target(Array(ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
   ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE))
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
index 674c814..7745043 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
@@ -18,7 +18,7 @@
 package org.apache.mxnetexamples.cnntextclassification
 
 import org.apache.mxnet.optimizer.RMSProp
-import org.apache.mxnet.{Context, Executor, Model, NDArray, Optimizer, Shape, Symbol, Uniform}
+import org.apache.mxnet.{Context, Executor, Model, NDArray, NDArrayCollector, Optimizer, Shape, Symbol, Uniform}
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
 
@@ -131,56 +131,58 @@ object CNNTextClassification {
       numTotal = 0f
       updateRate = 0
 
-      for (begin <- 0 until trainBatches.length by batchSize) {
-        val (batchD, batchL) = {
-          if (begin + batchSize <= trainBatches.length) {
-            val datas = trainBatches.drop(begin).take(batchSize)
-            val labels = trainLabels.drop(begin).take(batchSize)
-            (datas, labels)
-          } else {
-            val right = (begin + batchSize) - trainBatches.length
-            val left = trainBatches.length - begin
-            val datas = trainBatches.drop(begin).take(left) ++ trainBatches.take(right)
-            val labels = trainLabels.drop(begin).take(left) ++ trainLabels.take(right)
-            (datas, labels)
+      NDArrayCollector.auto().withScope {
+        for (begin <- 0 until trainBatches.length by batchSize) {
+          val (batchD, batchL) = {
+            if (begin + batchSize <= trainBatches.length) {
+              val datas = trainBatches.drop(begin).take(batchSize)
+              val labels = trainLabels.drop(begin).take(batchSize)
+              (datas, labels)
+            } else {
+              val right = (begin + batchSize) - trainBatches.length
+              val left = trainBatches.length - begin
+              val datas = trainBatches.drop(begin).take(left) ++ trainBatches.take(right)
+              val labels = trainLabels.drop(begin).take(left) ++ trainLabels.take(right)
+              (datas, labels)
+            }
+          }
+          numTotal += batchSize
+          model.data.set(batchD.flatten.flatten)
+          model.label.set(batchL)
+
+          model.cnnExec.forward(isTrain = true)
+          model.cnnExec.backward()
+
+          val tmpCorrect = {
+            val predLabel = NDArray.api.argmax_channel(model.cnnExec.outputs(0))
+            val result = predLabel.toArray.zip(batchL).map { predLabel =>
+              if (predLabel._1 == predLabel._2) 1
+              else 0
+            }.sum.toFloat
+            predLabel.dispose()
+            result
           }
-        }
-        numTotal += batchSize
-        model.data.set(batchD.flatten.flatten)
-        model.label.set(batchL)
-
-        model.cnnExec.forward(isTrain = true)
-        model.cnnExec.backward()
-
-        val tmpCorrect = {
-          val predLabel = NDArray.api.argmax_channel(model.cnnExec.outputs(0))
-          val result = predLabel.toArray.zip(batchL).map { predLabel =>
-            if (predLabel._1 == predLabel._2) 1
-            else 0
-          }.sum.toFloat
-          predLabel.dispose()
-          result
-        }
 
-        numCorrect = numCorrect + tmpCorrect
-        val norm = Math.sqrt(paramBlocks.map { case (idx, weight, grad, state, name) =>
-          val temp = NDArray.api.norm(grad / batchSize).disposeDepsExcept(grad)
-          val l2Norm = temp.toScalar
-          temp.dispose()
-          l2Norm * l2Norm
-        }.sum).toFloat
-
-        if (updateRate % 2 == 0) {
-          paramBlocks.foreach { case (idx, weight, grad, state, name) =>
-            if (norm > maxGradNorm) {
-              grad.set(grad.toArray.map(_ * (maxGradNorm / norm)))
-              opt.update(idx, weight, grad, state)
+          numCorrect = numCorrect + tmpCorrect
+          val norm = Math.sqrt(paramBlocks.map { case (idx, weight, grad, state, name) =>
+            val temp = NDArray.api.norm(grad / batchSize).disposeDepsExcept(grad)
+            val l2Norm = temp.toScalar
+            temp.dispose()
+            l2Norm * l2Norm
+          }.sum).toFloat
+
+          if (updateRate % 2 == 0) {
+            paramBlocks.foreach { case (idx, weight, grad, state, name) =>
+              if (norm > maxGradNorm) {
+                grad.set(grad.toArray.map(_ * (maxGradNorm / norm)))
+                opt.update(idx, weight, grad, state)
+              }
+              else opt.update(idx, weight, grad, state)
+              grad.set(0f)
             }
-            else opt.update(idx, weight, grad, state)
-            grad.set(0f)
           }
+          updateRate = updateRate + 1
         }
-        updateRate = updateRate + 1
       }
 
       // decay learning rate
@@ -237,30 +239,33 @@ object CNNTextClassification {
 
   def test(w2vFilePath : String, mrDatasetPath: String,
            ctx : Context, saveModelPath: String) : Float = {
-    val (numEmbed, word2vec) = DataHelper.loadGoogleModel(w2vFilePath)
-    val (datas, labels) = DataHelper.loadMSDataWithWord2vec(
-      mrDatasetPath, numEmbed, word2vec)
-    // randomly shuffle data
-    val randIdx = Random.shuffle((0 until datas.length).toList)
-    // split train/dev set
-    val (trainDats, devDatas) = {
-      val train = randIdx.dropRight(1000).map(datas(_)).toArray
-      val dev = randIdx.takeRight(1000).map(datas(_)).toArray
-      (train, dev)
-    }
-    val (trainLabels, devLabels) = {
-      val train = randIdx.dropRight(1000).map(labels(_)).toArray
-      val dev = randIdx.takeRight(1000).map(labels(_)).toArray
-      (train, dev)
+    val output = NDArrayCollector.auto().withScope {
+      val (numEmbed, word2vec) = DataHelper.loadGoogleModel(w2vFilePath)
+      val (datas, labels) = DataHelper.loadMSDataWithWord2vec(
+        mrDatasetPath, numEmbed, word2vec)
+      // randomly shuffle data
+      val randIdx = Random.shuffle((0 until datas.length).toList)
+      // split train/dev set
+      val (trainDats, devDatas) = {
+        val train = randIdx.dropRight(1000).map(datas(_)).toArray
+        val dev = randIdx.takeRight(1000).map(datas(_)).toArray
+        (train, dev)
+      }
+      val (trainLabels, devLabels) = {
+        val train = randIdx.dropRight(1000).map(labels(_)).toArray
+        val dev = randIdx.takeRight(1000).map(labels(_)).toArray
+        (train, dev)
+      }
+      // reshape for convolution input
+      val sentenceSize = datas(0).length
+      val batchSize = 100
+      val lr = 0.001f
+      val cnnModel = setupCnnModel(ctx, batchSize, sentenceSize, numEmbed)
+      val result = trainCNN(cnnModel, trainDats, trainLabels, devDatas, devLabels, batchSize,
+        saveModelPath, learningRate = lr)
+      result
     }
-    // reshpae for convolution input
-    val sentenceSize = datas(0).length
-    val batchSize = 100
-    val lr = 0.001f
-    val cnnModel = setupCnnModel(ctx, batchSize, sentenceSize, numEmbed)
-    val result = trainCNN(cnnModel, trainDats, trainLabels, devDatas, devLabels, batchSize,
-      saveModelPath, learningRate = lr)
-    result
+    output
   }
 
   def main(args: Array[String]): Unit = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
index a4b3479..df79f5b 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.customop
 
 import org.apache.mxnet.Callback.Speedometer
 import org.apache.mxnet.DType.DType
-import org.apache.mxnet.{Accuracy, Context, CustomOp, CustomOpProp, NDArray, Operator, Shape, Symbol, Xavier}
+import org.apache.mxnet.{Accuracy, Context, CustomOp, CustomOpProp, NDArray, NDArrayCollector, Operator, Shape, Symbol, Xavier}
 import org.apache.mxnet.optimizer.RMSProp
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
@@ -141,49 +141,50 @@ object ExampleCustomOp {
       evalMetric.reset()
       var nBatch = 0
       var epochDone = false
-
-      trainIter.reset()
-      while (!epochDone) {
-        var doReset = true
-        while (doReset && trainIter.hasNext) {
-          val dataBatch = trainIter.next()
-          argDict("data").set(dataBatch.data(0))
-          argDict("label").set(dataBatch.label(0))
-          executor.forward(isTrain = true)
-          executor.backward()
-          paramsGrads.foreach { case (idx, name, grad, optimState) =>
-            opt.update(idx, argDict(name), grad, optimState)
+      NDArrayCollector.auto().withScope {
+        trainIter.reset()
+        while (!epochDone) {
+          var doReset = true
+          while (doReset && trainIter.hasNext) {
+            val dataBatch = trainIter.next()
+            argDict("data").set(dataBatch.data(0))
+            argDict("label").set(dataBatch.label(0))
+            executor.forward(isTrain = true)
+            executor.backward()
+            paramsGrads.foreach { case (idx, name, grad, optimState) =>
+              opt.update(idx, argDict(name), grad, optimState)
+            }
+            evalMetric.update(dataBatch.label, executor.outputs)
+            nBatch += 1
+            batchEndCallback.invoke(epoch, nBatch, evalMetric)
+          }
+          if (doReset) {
+            trainIter.reset()
           }
-          evalMetric.update(dataBatch.label, executor.outputs)
-          nBatch += 1
-          batchEndCallback.invoke(epoch, nBatch, evalMetric)
+          epochDone = true
         }
-        if (doReset) {
-          trainIter.reset()
+        val (name, value) = evalMetric.get
+        name.zip(value).foreach { case (n, v) =>
+          logger.info(s"Epoch[$epoch] Train-accuracy=$v")
+        }
+        val toc = System.currentTimeMillis
+        logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
+
+        evalMetric.reset()
+        testIter.reset()
+        while (testIter.hasNext) {
+          val evalBatch = testIter.next()
+          argDict("data").set(evalBatch.data(0))
+          argDict("label").set(evalBatch.label(0))
+          executor.forward(isTrain = true)
+          evalMetric.update(evalBatch.label, executor.outputs)
+          evalBatch.dispose()
+        }
+        val (names, values) = evalMetric.get
+        names.zip(values).foreach { case (n, v) =>
+          logger.info(s"Epoch[$epoch] Validation-accuracy=$v")
+          validationAcc = Math.max(validationAcc, v)
         }
-        epochDone = true
-      }
-      val (name, value) = evalMetric.get
-      name.zip(value).foreach { case (n, v) =>
-        logger.info(s"Epoch[$epoch] Train-accuracy=$v")
-      }
-      val toc = System.currentTimeMillis
-      logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
-
-      evalMetric.reset()
-      testIter.reset()
-      while (testIter.hasNext) {
-        val evalBatch = testIter.next()
-        argDict("data").set(evalBatch.data(0))
-        argDict("label").set(evalBatch.label(0))
-        executor.forward(isTrain = true)
-        evalMetric.update(evalBatch.label, executor.outputs)
-        evalBatch.dispose()
-      }
-      val (names, values) = evalMetric.get
-      names.zip(values).foreach { case (n, v) =>
-        logger.info(s"Epoch[$epoch] Validation-accuracy=$v")
-        validationAcc = Math.max(validationAcc, v)
       }
     }
     executor.dispose()
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
index 70846ee..475d91f 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
@@ -17,7 +17,7 @@
 
 package org.apache.mxnetexamples.gan
 
-import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, Shape, Symbol, Xavier}
+import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, NDArrayCollector, Shape, Symbol, Xavier}
 import org.apache.mxnet.optimizer.Adam
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
@@ -104,75 +104,80 @@ object GanMnist {
 
   def runTraining(dataPath : String, context : Context,
                   outputPath : String, numEpoch : Int): Float = {
-    val lr = 0.0005f
-    val beta1 = 0.5f
-    val batchSize = 100
-    val randShape = Shape(batchSize, 100)
-    val dataShape = Shape(batchSize, 1, 28, 28)
-
-    val (symGen, symDec) =
-      makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
-
-    val gMod = new GANModule(
-      symGen,
-      symDec,
-      context = context,
-      dataShape = dataShape,
-      codeShape = randShape)
-
-    gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
-    gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
-
-    gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
-
-    val params = Map(
-      "image" -> s"$dataPath/train-images-idx3-ubyte",
-      "label" -> s"$dataPath/train-labels-idx1-ubyte",
-      "input_shape" -> s"(1, 28, 28)",
-      "batch_size" -> s"$batchSize",
-      "shuffle" -> "True"
-    )
-
-    val mnistIter = IO.MNISTIter(params)
-
-    val metricAcc = new CustomMetric(ferr, "ferr")
-
-    var t = 0
-    var dataBatch: DataBatch = null
-    var acc = 0.0f
-    for (epoch <- 0 until numEpoch) {
-      mnistIter.reset()
-      metricAcc.reset()
-      t = 0
-      while (mnistIter.hasNext) {
-        dataBatch = mnistIter.next()
-        gMod.update(dataBatch)
-        gMod.dLabel.set(0f)
-        metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
-        gMod.dLabel.set(1f)
-        metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
-
-        if (t % 50 == 0) {
-          val (name, value) = metricAcc.get
-          acc = value(0)
-          logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
-          Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)
-          val diff = gMod.tempDiffD
-          val arr = diff.toArray
-          val mean = arr.sum / arr.length
-          val std = {
-            val tmpA = arr.map(a => (a - mean) * (a - mean))
-            Math.sqrt(tmpA.sum / tmpA.length).toFloat
+    val output = NDArrayCollector.auto().withScope {
+      val lr = 0.0005f
+      val beta1 = 0.5f
+      val batchSize = 100
+      val randShape = Shape(batchSize, 100)
+      val dataShape = Shape(batchSize, 1, 28, 28)
+
+      val (symGen, symDec) =
+        makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
+
+      val gMod = new GANModule(
+        symGen,
+        symDec,
+        context = context,
+        dataShape = dataShape,
+        codeShape = randShape)
+
+      gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
+      gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
+
+      gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
+
+      val params = Map(
+        "image" -> s"$dataPath/train-images-idx3-ubyte",
+        "label" -> s"$dataPath/train-labels-idx1-ubyte",
+        "input_shape" -> s"(1, 28, 28)",
+        "batch_size" -> s"$batchSize",
+        "shuffle" -> "True"
+      )
+
+      val mnistIter = IO.MNISTIter(params)
+
+      val metricAcc = new CustomMetric(ferr, "ferr")
+
+      var t = 0
+      var dataBatch: DataBatch = null
+      var acc = 0.0f
+      for (epoch <- 0 until numEpoch) {
+        mnistIter.reset()
+        metricAcc.reset()
+        t = 0
+        while (mnistIter.hasNext) {
+          dataBatch = mnistIter.next()
+          NDArrayCollector.auto().withScope {
+            gMod.update(dataBatch)
+            gMod.dLabel.set(0f)
+            metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
+            gMod.dLabel.set(1f)
+            metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
+
+            if (t % 50 == 0) {
+              val (name, value) = metricAcc.get
+              acc = value(0)
+              logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
+              Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)
+              val diff = gMod.tempDiffD
+              val arr = diff.toArray
+              val mean = arr.sum / arr.length
+              val std = {
+                val tmpA = arr.map(a => (a - mean) * (a - mean))
+                Math.sqrt(tmpA.sum / tmpA.length).toFloat
+              }
+              diff.set((diff - mean) / std + 0.5f)
+              Viz.imSave("diff", outputPath, diff, flip = true)
+              Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
+            }
           }
-          diff.set((diff - mean) / std + 0.5f)
-          Viz.imSave("diff", outputPath, diff, flip = true)
-          Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
+          dataBatch.dispose()
+          t += 1
         }
-
-        t += 1
       }
+      acc
     }
-    acc
+    output
   }
 
   def main(args: Array[String]): Unit = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
index bd0ce45..2f024fd 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
@@ -93,16 +93,18 @@ object TrainMnist {
   }
 
   def test(dataPath : String) : Float = {
-    val (dataShape, net) = (Shape(784), getMlp)
-    val devs = Array(Context.cpu(0))
-    val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
-    val Acc = ModelTrain.fit(dataDir = dataPath,
-      batchSize = 128, numExamples = 60000, devs = devs,
-      network = net, dataLoader = getIterator(dataShape),
-      kvStore = "local", numEpochs = 10)
-    logger.info("Finish test fit ...")
-    val (_, num) = Acc.get
-    num(0)
+    NDArrayCollector.auto().withScope {
+      val (dataShape, net) = (Shape(784), getMlp)
+      val devs = Array(Context.cpu(0))
+      val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
+      val Acc = ModelTrain.fit(dataDir = dataPath,
+        batchSize = 128, numExamples = 60000, devs = devs,
+        network = net, dataLoader = getIterator(dataShape),
+        kvStore = "local", numEpochs = 10)
+      logger.info("Finish test fit ...")
+      val (_, num) = Acc.get
+      num(0)
+    }
   }
 
 
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
index 2a0d967..f6e4fe0 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
@@ -17,10 +17,9 @@
 
 package org.apache.mxnetexamples.infer.imageclassifier
 
-import org.apache.mxnet.Shape
+import org.apache.mxnet._
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
-import org.apache.mxnet.{DType, DataDesc, Context}
 import org.apache.mxnet.infer.ImageClassifier
 
 import scala.collection.JavaConverters._
@@ -43,47 +42,50 @@ object ImageClassifierExample {
   def runInferenceOnSingleImage(modelPathPrefix: String, inputImagePath: String,
                                 context: Array[Context]):
   IndexedSeq[IndexedSeq[(String, Float)]] = {
-    val dType = DType.Float32
-    val inputShape = Shape(1, 3, 224, 224)
+    NDArrayCollector.auto().withScope {
+      val dType = DType.Float32
+      val inputShape = Shape(1, 3, 224, 224)
 
-    val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
+      val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
 
-    // Create object of ImageClassifier class
-    val imgClassifier: ImageClassifier = new
-        ImageClassifier(modelPathPrefix, inputDescriptor, context)
+      // Create object of ImageClassifier class
+      val imgClassifier: ImageClassifier = new
+          ImageClassifier(modelPathPrefix, inputDescriptor, context)
 
-    // Loading single image from file and getting BufferedImage
-    val img = ImageClassifier.loadImageFromFile(inputImagePath)
+      // Loading single image from file and getting BufferedImage
+      val img = ImageClassifier.loadImageFromFile(inputImagePath)
 
-    // Running inference on single image
-    val output = imgClassifier.classifyImage(img, Some(5))
-
-    output
+      // Running inference on single image
+      val output = imgClassifier.classifyImage(img, Some(5))
+      output
+    }
   }
 
   def runInferenceOnBatchOfImage(modelPathPrefix: String, inputImageDir: String,
                                  context: Array[Context]):
   IndexedSeq[IndexedSeq[(String, Float)]] = {
-    val dType = DType.Float32
-    val inputShape = Shape(1, 3, 224, 224)
+    NDArrayCollector.auto().withScope {
+      val dType = DType.Float32
+      val inputShape = Shape(1, 3, 224, 224)
 
-    val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
+      val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
 
-    // Create object of ImageClassifier class
-    val imgClassifier: ImageClassifier = new
-        ImageClassifier(modelPathPrefix, inputDescriptor, context)
+      // Create object of ImageClassifier class
+      val imgClassifier: ImageClassifier = new
+          ImageClassifier(modelPathPrefix, inputDescriptor, context)
 
-    // Loading batch of images from the directory path
-    val batchFiles = generateBatches(inputImageDir, 20)
-    var outputList = IndexedSeq[IndexedSeq[(String, Float)]]()
+      // Loading batch of images from the directory path
+      val batchFiles = generateBatches(inputImageDir, 20)
+      var outputList = IndexedSeq[IndexedSeq[(String, Float)]]()
 
-    for (batchFile <- batchFiles) {
-      val imgList = ImageClassifier.loadInputBatch(batchFile)
-      // Running inference on batch of images loaded in previous step
-      outputList ++= imgClassifier.classifyImageBatch(imgList, Some(5))
-    }
+      for (batchFile <- batchFiles) {
+        val imgList = ImageClassifier.loadInputBatch(batchFile)
+        // Running inference on batch of images loaded in previous step
+        outputList ++= imgClassifier.classifyImageBatch(imgList, Some(5))
+      }
 
-    outputList
+      outputList
+    }
   }
 
   def generateBatches(inputImageDirPath: String, batchSize: Int = 100): List[List[String]] = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
index 7c6c7ef..0edde9e 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.infer.objectdetector
 
 import java.io.File
 
-import org.apache.mxnet.{Context, DType, DataDesc, Shape}
+import org.apache.mxnet._
 import org.apache.mxnet.infer._
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
@@ -54,37 +54,41 @@ object SSDClassifierExample {
   def runObjectDetectionSingle(modelPathPrefix: String, inputImagePath: String,
                                context: Array[Context]):
   IndexedSeq[IndexedSeq[(String, Array[Float])]] = {
-    val dType = DType.Float32
-    val inputShape = Shape(1, 3, 512, 512)
-    // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
-    val outputShape = Shape(1, 6132, 6)
-    val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
-    val img = ImageClassifier.loadImageFromFile(inputImagePath)
-    val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context)
-    val output = objDetector.imageObjectDetect(img, Some(3))
-
-    output
+    NDArrayCollector.auto().withScope {
+      val dType = DType.Float32
+      val inputShape = Shape(1, 3, 512, 512)
+      // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
+      val outputShape = Shape(1, 6132, 6)
+      val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
+      val img = ImageClassifier.loadImageFromFile(inputImagePath)
+      val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context)
+      val output = objDetector.imageObjectDetect(img, Some(3))
+
+      output
+    }
   }
 
   def runObjectDetectionBatch(modelPathPrefix: String, inputImageDir: String,
                               context: Array[Context]):
   IndexedSeq[IndexedSeq[(String, Array[Float])]] = {
-    val dType = DType.Float32
-    val inputShape = Shape(1, 3, 512, 512)
-    // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
-    val outputShape = Shape(1, 6132, 6)
-    val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
-    val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context)
-    // Loading batch of images from the directory path
-    val batchFiles = generateBatches(inputImageDir, 20)
-    var outputList = IndexedSeq[IndexedSeq[(String, Array[Float])]]()
-
-    for (batchFile <- batchFiles) {
-      val imgList = ImageClassifier.loadInputBatch(batchFile)
-      // Running inference on batch of images loaded in previous step
-      outputList ++= objDetector.imageBatchObjectDetect(imgList, Some(5))
+    NDArrayCollector.auto().withScope {
+      val dType = DType.Float32
+      val inputShape = Shape(1, 3, 512, 512)
+      // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
+      val outputShape = Shape(1, 6132, 6)
+      val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
+      val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context)
+      // Loading batch of images from the directory path
+      val batchFiles = generateBatches(inputImageDir, 20)
+      var outputList = IndexedSeq[IndexedSeq[(String, Array[Float])]]()
+
+      for (batchFile <- batchFiles) {
+        val imgList = ImageClassifier.loadInputBatch(batchFile)
+        // Running inference on batch of images loaded in previous step
+        outputList ++= objDetector.imageBatchObjectDetect(imgList, Some(5))
+      }
+      outputList
     }
-    outputList
   }
 
   def generateBatches(inputImageDirPath: String, batchSize: Int = 100): List[List[String]] = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
index 825e465..bfde558 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
@@ -25,11 +25,9 @@ import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
 import org.apache.commons.io.FileUtils
-
-import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, NDArray, Shape, Symbol, Xavier}
+import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, Executor, NDArray, NDArrayCollector, Shape, Symbol, Xavier}
 import org.apache.mxnet.DType.DType
 import org.apache.mxnet.optimizer.RMSProp
-import org.apache.mxnet.Executor
 import org.apache.mxnetexamples.Util
 
 import scala.collection.immutable.ListMap
@@ -223,120 +221,123 @@ object ExampleMultiTask {
 
   def train(batchSize: Int, numEpoch: Int, ctx: Context, modelDirPath: String):
   (Executor, MultiAccuracy) = {
-    val lr = 0.001f
-    val network = ExampleMultiTask.buildNetwork()
-    val (trainIter, valIter) =
-      Data.mnistIterator(modelDirPath, batchSize = batchSize, inputShape = Shape(784))
-    val trainMultiIt = new MultiMnistIterator(trainIter)
-    val valMultiIter = new MultiMnistIterator(valIter)
-
-    val datasAndLabels = trainMultiIt.provideData ++ trainMultiIt.provideLabel
-
-    val (argShapes, outputShapes, auxShapes) = network.inferShape(trainMultiIt.provideData("data"))
-    val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
-
-    val argNames = network.listArguments
-    val argDict = argNames.zip(argShapes.map(NDArray.empty(_, ctx))).toMap
-
-    val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
-      !datasAndLabels.contains(name)
-    }.map(x => x._1 -> NDArray.empty(x._2, ctx)).toMap
-
-    argDict.foreach { case (name, ndArray) =>
-      if (!datasAndLabels.contains(name)) {
-        initializer.initWeight(name, ndArray)
+    NDArrayCollector.auto().withScope {
+      val lr = 0.001f
+      val network = ExampleMultiTask.buildNetwork()
+      val (trainIter, valIter) =
+        Data.mnistIterator(modelDirPath, batchSize = batchSize, inputShape = Shape(784))
+      val trainMultiIt = new MultiMnistIterator(trainIter)
+      val valMultiIter = new MultiMnistIterator(valIter)
+
+      val datasAndLabels = trainMultiIt.provideData ++ trainMultiIt.provideLabel
+
+      val (argShapes, outputShapes, auxShapes)
+      = network.inferShape(trainMultiIt.provideData("data"))
+      val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
+
+      val argNames = network.listArguments
+      val argDict = argNames.zip(argShapes.map(NDArray.empty(_, ctx))).toMap
+      // NOTE(review): collected by this scope, so the returned executor's arrays get disposed
+      val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
+        !datasAndLabels.contains(name)
+      }.map(x => x._1 -> NDArray.empty(x._2, ctx)).toMap
+
+      argDict.foreach { case (name, ndArray) =>
+        if (!datasAndLabels.contains(name)) {
+          initializer.initWeight(name, ndArray)
+        }
       }
-    }
-
-    val data = argDict("data")
-    val label1 = argDict("softmaxoutput0_label")
-    val label2 = argDict("softmaxoutput1_label")
-    val maxGradNorm = 0.5f
-    val executor = network.bind(ctx, argDict, gradDict)
-
-    val opt = new RMSProp(learningRate = lr, wd = 0.00001f)
 
-    val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
-      (idx, name, grad, opt.createState(idx, argDict(name)))
-    }
+      val data = argDict("data")
+      val label1 = argDict("softmaxoutput0_label")
+      val label2 = argDict("softmaxoutput1_label")
+      val maxGradNorm = 0.5f
+      val executor = network.bind(ctx, argDict, gradDict)
 
-    val evalMetric = new ExampleMultiTask.MultiAccuracy(num = 2, name = "multi_accuracy")
-    val batchEndCallback = new ExampleMultiTask.Speedometer(batchSize, 50)
+      val opt = new RMSProp(learningRate = lr, wd = 0.00001f)
 
-    for (epoch <- 0 until numEpoch) {
-      // Training phase
-      val tic = System.currentTimeMillis
-      evalMetric.reset()
-      var nBatch = 0
-      var epochDone = false
-      // Iterate over training data.
-      trainMultiIt.reset()
+      val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
+        (idx, name, grad, opt.createState(idx, argDict(name)))
+      }
 
-      while (!epochDone) {
-        var doReset = true
-        while (doReset && trainMultiIt.hasNext) {
-          val dataBatch = trainMultiIt.next()
+      val evalMetric = new ExampleMultiTask.MultiAccuracy(num = 2, name = "multi_accuracy")
+      val batchEndCallback = new ExampleMultiTask.Speedometer(batchSize, 50)
+
+      for (epoch <- 0 until numEpoch) {
+        // Training phase
+        val tic = System.currentTimeMillis
+        evalMetric.reset()
+        var nBatch = 0
+        var epochDone = false
+        // Iterate over training data.
+        trainMultiIt.reset()
+
+        while (!epochDone) {
+          var doReset = true
+          while (doReset && trainMultiIt.hasNext) {
+            // per-batch scope frees temporaries; the batch is disposed explicitly at the end
+            NDArrayCollector.auto().withScope {
+              val dataBatch = trainMultiIt.next()
+              data.set(dataBatch.data(0))
+              label1.set(dataBatch.label(0))
+              label2.set(dataBatch.label(1))
+
+              executor.forward(isTrain = true)
+              executor.backward()
+              val norm = Math.sqrt(paramsGrads.map { case (idx, name, grad, optimState) =>
+                val l2Norm = NDArray.api.norm(data = (grad / batchSize)).toScalar
+                l2Norm * l2Norm
+              }.sum).toFloat
+
+              paramsGrads.foreach { case (idx, name, grad, optimState) =>
+                if (norm > maxGradNorm) grad.set(grad.toArray.map(_ * (maxGradNorm / norm)))
+                opt.update(idx, argDict(name), grad, optimState)
+              }
 
-          data.set(dataBatch.data(0))
-          label1.set(dataBatch.label(0))
-          label2.set(dataBatch.label(1))
+              // evaluate at end, so out_cpu_array can lazy copy
+              evalMetric.update(dataBatch.label, executor.outputs)
 
-          executor.forward(isTrain = true)
-          executor.backward()
-
-          val norm = Math.sqrt(paramsGrads.map { case (idx, name, grad, optimState) =>
-            val l2Norm = NDArray.api.norm(data = (grad / batchSize)).toScalar
-            l2Norm * l2Norm
-          }.sum).toFloat
-
-          paramsGrads.foreach { case (idx, name, grad, optimState) =>
-            if (norm > maxGradNorm) {
-              grad.set(grad.toArray.map(_ * (maxGradNorm / norm)))
-              opt.update(idx, argDict(name), grad, optimState)
-            } else opt.update(idx, argDict(name), grad, optimState)
+              nBatch += 1
+              batchEndCallback.invoke(epoch, nBatch, evalMetric)
+              dataBatch.dispose()
+            }
           }
-
-          // evaluate at end, so out_cpu_array can lazy copy
-          evalMetric.update(dataBatch.label, executor.outputs)
-
-          nBatch += 1
-          batchEndCallback.invoke(epoch, nBatch, evalMetric)
+          if (doReset) {
+            trainMultiIt.reset()
+          }
+          // this epoch is done
+          epochDone = true
         }
-        if (doReset) {
-          trainMultiIt.reset()
+        var nameVals = evalMetric.get
+        nameVals.foreach { case (name, value) =>
+          logger.info(s"Epoch[$epoch] Train-$name=$value")
         }
-        // this epoch is done
-        epochDone = true
-      }
-      var nameVals = evalMetric.get
-      nameVals.foreach { case (name, value) =>
-        logger.info(s"Epoch[$epoch] Train-$name=$value")
-      }
-      val toc = System.currentTimeMillis
-      logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
+        val toc = System.currentTimeMillis
+        logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
 
-      evalMetric.reset()
-      valMultiIter.reset()
-      while (valMultiIter.hasNext) {
-        val evalBatch = valMultiIter.next()
+        evalMetric.reset()
+        valMultiIter.reset()
+        while (valMultiIter.hasNext) {
+          val evalBatch = valMultiIter.next()
 
-        data.set(evalBatch.data(0))
-        label1.set(evalBatch.label(0))
-        label2.set(evalBatch.label(1))
+          data.set(evalBatch.data(0))
+          label1.set(evalBatch.label(0))
+          label2.set(evalBatch.label(1))
 
-        executor.forward(isTrain = true)
+          executor.forward(isTrain = false) // validation: forward only, no backward pass
 
-        evalMetric.update(evalBatch.label, executor.outputs)
-        evalBatch.dispose()
-      }
+          evalMetric.update(evalBatch.label, executor.outputs)
+          evalBatch.dispose()
+        }
 
-      nameVals = evalMetric.get
-      nameVals.foreach { case (name, value) =>
-        logger.info(s"Epoch[$epoch] Validation-$name=$value")
+        nameVals = evalMetric.get
+        nameVals.foreach { case (name, value) =>
+          logger.info(s"Epoch[$epoch] Validation-$name=$value")
+        }
       }
-    }
 
-    (executor, evalMetric)
+      (executor, evalMetric)
+    }
   }
 
   def main(args: Array[String]): Unit = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
index f98d725..1767cab 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
@@ -170,102 +170,103 @@ object NeuralStyle {
                   contentWeight : Float, tvWeight : Float, gaussianRadius : Int,
                   lr: Float, maxNumEpochs: Int, maxLongEdge: Int,
                   saveEpochs : Int, stopEps: Float) : Unit = {
+    NDArrayCollector.auto().withScope {
+      val contentNp = preprocessContentImage(contentImage, maxLongEdge, dev)
+      val styleNp = preprocessStyleImage(styleImage, contentNp.shape, dev)
+      val size = (contentNp.shape(2), contentNp.shape(3))
 
-    val contentNp = preprocessContentImage(contentImage, maxLongEdge, dev)
-    val styleNp = preprocessStyleImage(styleImage, contentNp.shape, dev)
-    val size = (contentNp.shape(2), contentNp.shape(3))
-
-    val (style, content) = ModelVgg19.getSymbol
-    val (gram, gScale) = styleGramSymbol(size, style)
-    var modelExecutor = ModelVgg19.getExecutor(gram, content, modelPath, size, dev)
-
-    modelExecutor.data.set(styleNp)
-    modelExecutor.executor.forward()
-
-    val styleArray = modelExecutor.style.map(_.copyTo(Context.cpu()))
-    modelExecutor.data.set(contentNp)
-    modelExecutor.executor.forward()
-    val contentArray = modelExecutor.content.copyTo(Context.cpu())
-
-    // delete the executor
-    modelExecutor.argDict.foreach(ele => ele._2.dispose())
-    modelExecutor.content.dispose()
-    modelExecutor.data.dispose()
-    modelExecutor.dataGrad.dispose()
-    modelExecutor.style.foreach(_.dispose())
-    modelExecutor.executor.dispose()
-    modelExecutor = null
-
-    val (styleLoss, contentLoss) = getLoss(gram, content)
-    modelExecutor = ModelVgg19.getExecutor(
-      styleLoss, contentLoss, modelPath, size, dev)
-
-    val gradArray = {
-      var tmpGA = Array[NDArray]()
-      for (i <- 0 until styleArray.length) {
-        modelExecutor.argDict(s"target_gram_$i").set(styleArray(i))
-        tmpGA = tmpGA :+ NDArray.ones(Shape(1), dev) * (styleWeight / gScale(i))
-      }
-      tmpGA :+ NDArray.ones(Shape(1), dev) * contentWeight
-    }
-
-    modelExecutor.argDict("target_content").set(contentArray)
-
-    // train
-    val img = Random.uniform(-0.1f, 0.1f, contentNp.shape, dev)
-    val lrFS = new FactorScheduler(step = 10, factor = 0.9f)
+      val (style, content) = ModelVgg19.getSymbol
+      val (gram, gScale) = styleGramSymbol(size, style)
+      var modelExecutor = ModelVgg19.getExecutor(gram, content, modelPath, size, dev)
 
-    saveImage(contentNp, s"${outputDir}/input.jpg", gaussianRadius)
-    saveImage(styleNp, s"${outputDir}/style.jpg", gaussianRadius)
-
-    val optimizer = new Adam(
-      learningRate = lr,
-      wd = 0.005f,
-      lrScheduler = lrFS)
-    val optimState = optimizer.createState(0, img)
-
-    logger.info(s"start training arguments")
-
-    var oldImg = img.copyTo(dev)
-    val clipNorm = img.shape.toVector.reduce(_ * _)
-    val tvGradExecutor = getTvGradExecutor(img, dev, tvWeight)
-    var eps = 0f
-    var trainingDone = false
-    var e = 0
-    while (e < maxNumEpochs && !trainingDone) {
-      modelExecutor.data.set(img)
+      modelExecutor.data.set(styleNp)
       modelExecutor.executor.forward()
-      modelExecutor.executor.backward(gradArray)
 
-      val gNorm = NDArray.norm(modelExecutor.dataGrad).toScalar
-      if (gNorm > clipNorm) {
-        modelExecutor.dataGrad.set(modelExecutor.dataGrad * (clipNorm / gNorm))
-      }
-      tvGradExecutor match {
-        case Some(executor) => {
-          executor.forward()
-          optimizer.update(0, img,
-            modelExecutor.dataGrad + executor.outputs(0),
-            optimState)
+      val styleArray = modelExecutor.style.map(_.copyTo(Context.cpu()))
+      modelExecutor.data.set(contentNp)
+      modelExecutor.executor.forward()
+      val contentArray = modelExecutor.content.copyTo(Context.cpu())
+
+      // delete the executor
+      modelExecutor.argDict.foreach(ele => ele._2.dispose())
+      modelExecutor.content.dispose()
+      modelExecutor.data.dispose()
+      modelExecutor.dataGrad.dispose()
+      modelExecutor.style.foreach(_.dispose())
+      modelExecutor.executor.dispose()
+      modelExecutor = null
+
+      val (styleLoss, contentLoss) = getLoss(gram, content)
+      modelExecutor = ModelVgg19.getExecutor(
+        styleLoss, contentLoss, modelPath, size, dev)
+
+      val gradArray = {
+        var tmpGA = Array[NDArray]()
+        for (i <- 0 until styleArray.length) {
+          modelExecutor.argDict(s"target_gram_$i").set(styleArray(i))
+          tmpGA = tmpGA :+ NDArray.ones(Shape(1), dev) * (styleWeight / gScale(i))
         }
-        case None =>
-          optimizer.update(0, img, modelExecutor.dataGrad, optimState)
+        tmpGA :+ NDArray.ones(Shape(1), dev) * contentWeight
       }
-      eps = (NDArray.norm(oldImg - img) / NDArray.norm(img)).toScalar
-      oldImg.set(img)
-      logger.info(s"epoch $e, relative change $eps")
 
-      if (eps < stopEps) {
-        logger.info("eps < args.stop_eps, training finished")
-        trainingDone = true
-      }
-      if ((e + 1) % saveEpochs == 0) {
-        saveImage(img, s"${outputDir}/tmp_${e + 1}.jpg", gaussianRadius)
+      modelExecutor.argDict("target_content").set(contentArray)
+
+      // train
+      val img = Random.uniform(-0.1f, 0.1f, contentNp.shape, dev)
+      val lrFS = new FactorScheduler(step = 10, factor = 0.9f)
+
+      saveImage(contentNp, s"${outputDir}/input.jpg", gaussianRadius)
+      saveImage(styleNp, s"${outputDir}/style.jpg", gaussianRadius)
+
+      val optimizer = new Adam(
+        learningRate = lr,
+        wd = 0.005f,
+        lrScheduler = lrFS)
+      val optimState = optimizer.createState(0, img)
+
+      logger.info(s"start training arguments")
+
+      var oldImg = img.copyTo(dev)
+      val clipNorm = img.shape.toVector.reduce(_ * _)
+      val tvGradExecutor = getTvGradExecutor(img, dev, tvWeight)
+      var eps = 0f
+      var trainingDone = false
+      var e = 0
+      while (e < maxNumEpochs && !trainingDone) {
+        // Per-epoch scope: temporaries created below (norms, scaled gradients, ...) are
+        // disposed at the end of each epoch instead of piling up until training ends.
+        NDArrayCollector.auto().withScope {
+          modelExecutor.data.set(img)
+          modelExecutor.executor.forward()
+          modelExecutor.executor.backward(gradArray)
+          val gNorm = NDArray.norm(modelExecutor.dataGrad).toScalar
+          if (gNorm > clipNorm) {
+            modelExecutor.dataGrad.set(modelExecutor.dataGrad * (clipNorm / gNorm))
+          }
+          tvGradExecutor match {
+            case Some(executor) =>
+              executor.forward()
+              optimizer.update(0, img, modelExecutor.dataGrad + executor.outputs(0), optimState)
+            case None =>
+              optimizer.update(0, img, modelExecutor.dataGrad, optimState)
+          }
+          eps = (NDArray.norm(oldImg - img) / NDArray.norm(img)).toScalar
+          oldImg.set(img)
+          logger.info(s"epoch $e, relative change $eps")
+
+          if (eps < stopEps) {
+            logger.info("eps < args.stop_eps, training finished")
+            trainingDone = true
+          }
+          if ((e + 1) % saveEpochs == 0) {
+            saveImage(img, s"${outputDir}/tmp_${e + 1}.jpg", gaussianRadius)
+          }
+        }
+        e = e + 1
       }
-      e = e + 1
+      saveImage(img, s"${outputDir}/out.jpg", gaussianRadius)
+      logger.info("Finish fit ...")
     }
-    saveImage(img, s"${outputDir}/out.jpg", gaussianRadius)
-    logger.info("Finish fit ...")
   }
 
   def main(args: Array[String]): Unit = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
index 5410fb9..b1e6634 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
@@ -17,7 +17,7 @@
 
 package org.apache.mxnetexamples.neuralstyle.end2end
 
-import org.apache.mxnet.{Context, Shape}
+import org.apache.mxnet.{Context, NDArrayCollector, Shape}
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
 
@@ -29,28 +29,32 @@ object BoostInference {
 
   def runInference(modelPath: String, outputPath: String, guassianRadius : Int,
                    inputImage : String, ctx : Context): Unit = {
-    val dShape = Shape(1, 3, 480, 640)
-    val clipNorm = 1.0f * dShape.product
-    // generator
-    val gens = Array(
-      GenV4.getModule("g0", dShape, ctx, isTrain = false),
-      GenV3.getModule("g1", dShape, ctx, isTrain = false),
-      GenV3.getModule("g2", dShape, ctx, isTrain = false),
-      GenV4.getModule("g3", dShape, ctx, isTrain = false)
-    )
-    gens.zipWithIndex.foreach { case (gen, i) =>
-      gen.loadParams(s"$modelPath/$i/v3_0002-0026000.params")
-    }
+    NDArrayCollector.auto().withScope {
+      val dShape = Shape(1, 3, 480, 640)
+      // stages are chained: each one consumes the previous stage's output image
+      // generator
+      val gens = Array(
+        GenV4.getModule("g0", dShape, ctx, isTrain = false),
+        GenV3.getModule("g1", dShape, ctx, isTrain = false),
+        GenV3.getModule("g2", dShape, ctx, isTrain = false),
+        GenV4.getModule("g3", dShape, ctx, isTrain = false)
+      )
+      gens.zipWithIndex.foreach { case (gen, i) =>
+        gen.loadParams(s"$modelPath/$i/v3_0002-0026000.params")
+      }
 
-    val contentNp =
-      DataProcessing.preprocessContentImage(s"$inputImage", dShape, ctx)
-    var data = Array(contentNp)
-    for (i <- 0 until gens.length) {
-      gens(i).forward(data.takeRight(1))
-      val newImg = gens(i).getOutputs()(0)
-      data :+= newImg
-      DataProcessing.saveImage(newImg, s"$outputPath/out_$i.jpg", guassianRadius)
-      logger.info(s"Converted image: $outputPath/out_$i.jpg")
+      val contentNp =
+        DataProcessing.preprocessContentImage(s"$inputImage", dShape, ctx)
+      var data = Array(contentNp)
+      for (i <- 0 until gens.length) {
+        NDArrayCollector.auto().withScope {
+          gens(i).forward(data.takeRight(1))
+          val newImg = gens(i).getOutputs()(0)
+          data :+= newImg
+          DataProcessing.saveImage(newImg, s"$outputPath/out_$i.jpg", guassianRadius)
+          logger.info(s"Converted image: $outputPath/out_$i.jpg")
+        }
+      }
     }
   }
 
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
index 08b4c85..8246f44 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.neuralstyle.end2end
 
 import java.io.File
 
-import org.apache.mxnet.{Context, Executor, NDArray, Shape, Symbol}
+import org.apache.mxnet.{Context, Executor, NDArray, NDArrayCollector, Shape, Symbol}
 import org.apache.mxnet.optimizer.SGD
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
@@ -56,117 +56,121 @@ object BoostTrain {
 
   def runTraining(dataPath : String, vggModelPath: String, ctx : Context,
                   styleImage : String, saveModelPath : String) : Unit = {
-    // params
-    val vggParams = NDArray.load2Map(vggModelPath)
-    val styleWeight = 1.2f
-    val contentWeight = 10f
-    val dShape = Shape(1, 3, 384, 384)
-    val clipNorm = 0.05f * dShape.product
-    val modelPrefix = "v3"
-    // init style
-    val styleNp = DataProcessing.preprocessStyleImage(styleImage, dShape, ctx)
-    var styleMod = Basic.getStyleModule("style", dShape, ctx, vggParams)
-    styleMod.forward(Array(styleNp))
-    val styleArray = styleMod.getOutputs().map(_.copyTo(Context.cpu()))
-    styleMod.dispose()
-    styleMod = null
-
-    // content
-    val contentMod = Basic.getContentModule("content", dShape, ctx, vggParams)
-
-    // loss
-    val (loss, gScale) = Basic.getLossModule("loss", dShape, ctx, vggParams)
-    val extraArgs = (0 until styleArray.length)
-      .map( i => s"target_gram_$i" -> styleArray(i)).toMap
-    loss.setParams(extraArgs)
-    var gradArray = Array[NDArray]()
-    for (i <- 0 until styleArray.length) {
-      gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * (styleWeight / gScale(i)))
-    }
-    gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * contentWeight)
-
-    // generator
-    val gens = Array(
-      GenV4.getModule("g0", dShape, ctx),
-      GenV3.getModule("g1", dShape, ctx),
-      GenV3.getModule("g2", dShape, ctx),
-      GenV4.getModule("g3", dShape, ctx)
-    )
-    gens.foreach { gen =>
-      val opt = new SGD(learningRate = 1e-4f,
-        momentum = 0.9f,
-        wd = 5e-3f,
-        clipGradient = 5f)
-      gen.initOptimizer(opt)
-    }
+    NDArrayCollector.auto().withScope {
+      // params
+      val vggParams = NDArray.load2Map(vggModelPath)
+      val styleWeight = 1.2f
+      val contentWeight = 10f
+      val dShape = Shape(1, 3, 384, 384)
+      val clipNorm = 0.05f * dShape.product
+      val modelPrefix = "v3"
+      // init style
+      val styleNp = DataProcessing.preprocessStyleImage(styleImage, dShape, ctx)
+      var styleMod = Basic.getStyleModule("style", dShape, ctx, vggParams)
+      styleMod.forward(Array(styleNp))
+      val styleArray = styleMod.getOutputs().map(_.copyTo(Context.cpu()))
+      styleMod.dispose()
+      styleMod = null
+
+      // content
+      val contentMod = Basic.getContentModule("content", dShape, ctx, vggParams)
+
+      // loss
+      val (loss, gScale) = Basic.getLossModule("loss", dShape, ctx, vggParams)
+      val extraArgs = (0 until styleArray.length)
+        .map(i => s"target_gram_$i" -> styleArray(i)).toMap
+      loss.setParams(extraArgs)
+      var gradArray = Array[NDArray]()
+      for (i <- 0 until styleArray.length) {
+        gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * (styleWeight / gScale(i)))
+      }
+      gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * contentWeight)
+
+      // generator
+      val gens = Array(
+        GenV4.getModule("g0", dShape, ctx),
+        GenV3.getModule("g1", dShape, ctx),
+        GenV3.getModule("g2", dShape, ctx),
+        GenV4.getModule("g3", dShape, ctx)
+      )
+      gens.foreach { gen =>
+        val opt = new SGD(learningRate = 1e-4f,
+          momentum = 0.9f,
+          wd = 5e-3f,
+          clipGradient = 5f)
+        gen.initOptimizer(opt)
+      }
 
-    var filelist = new File(dataPath).list().toList
-    val numImage = filelist.length
-    logger.info(s"Dataset size: $numImage")
+      var filelist = new File(dataPath).list().toList
+      val numImage = filelist.length
+      logger.info(s"Dataset size: $numImage")
 
-    val tvWeight = 1e-2f
+      val tvWeight = 1e-2f
 
-    val startEpoch = 0
-    val endEpoch = 3
+      val startEpoch = 0
+      val endEpoch = 3
 
-    for (k <- 0 until gens.length) {
-      val path = new File(s"${saveModelPath}/$k")
-      if (!path.exists()) path.mkdir()
-    }
+      for (k <- 0 until gens.length) {
+        val path = new File(s"${saveModelPath}/$k")
+        if (!path.exists()) path.mkdir()
+      }
 
-    // train
-    for (i <- startEpoch until endEpoch) {
-      filelist = Random.shuffle(filelist)
-      for (idx <- filelist.indices) {
-        var dataArray = Array[NDArray]()
-        var lossGradArray = Array[NDArray]()
-        val data =
-          DataProcessing.preprocessContentImage(s"${dataPath}/${filelist(idx)}", dShape, ctx)
-        dataArray = dataArray :+ data
-        // get content
-        contentMod.forward(Array(data))
-        // set target content
-        loss.setParams(Map("target_content" -> contentMod.getOutputs()(0)))
-        // gen_forward
-        for (k <- 0 until gens.length) {
-          gens(k).forward(dataArray.takeRight(1))
-          dataArray = dataArray :+ gens(k).getOutputs()(0)
-          // loss forward
-          loss.forward(dataArray.takeRight(1))
-          loss.backward(gradArray)
-          lossGradArray = lossGradArray :+ loss.getInputGrads()(0)
-        }
-        val grad = NDArray.zeros(data.shape, ctx)
-        for (k <- gens.length - 1 to 0 by -1) {
-          val tvGradExecutor = getTvGradExecutor(gens(k).getOutputs()(0), ctx, tvWeight)
-          tvGradExecutor.forward()
-          grad += lossGradArray(k) + tvGradExecutor.outputs(0)
-          val gNorm = NDArray.norm(grad)
-          if (gNorm.toScalar > clipNorm) {
-            grad *= clipNorm / gNorm.toScalar
-          }
-          gens(k).backward(Array(grad))
-          gens(k).update()
-          gNorm.dispose()
-          tvGradExecutor.dispose()
-        }
-        grad.dispose()
-        if (idx % 20 == 0) {
-          logger.info(s"Epoch $i: Image $idx")
-          for (k <- 0 until gens.length) {
-            val n = NDArray.norm(gens(k).getInputGrads()(0))
-            logger.info(s"Data Norm : ${n.toScalar / dShape.product}")
-            n.dispose()
-          }
-        }
-        if (idx % 1000 == 0) {
-          for (k <- 0 until gens.length) {
-            gens(k).saveParams(
-              s"${saveModelPath}/$k/${modelPrefix}_" +
-                s"${"%04d".format(i)}-${"%07d".format(idx)}.params")
+      // train
+      for (i <- startEpoch until endEpoch) {
+        filelist = Random.shuffle(filelist)
+        for (idx <- filelist.indices) {
+          // Per-image scope: every NDArray created for this image (input, grad, norms, ...) is
+          // disposed when it exits; only the tv-gradient Executor needs an explicit dispose.
+          NDArrayCollector.auto().withScope {
+            var dataArray = Array[NDArray]()
+            var lossGradArray = Array[NDArray]()
+            val data =
+              DataProcessing.preprocessContentImage(s"${dataPath}/${filelist(idx)}", dShape, ctx)
+            dataArray = dataArray :+ data
+            // get content
+            contentMod.forward(Array(data))
+            // set target content
+            loss.setParams(Map("target_content" -> contentMod.getOutputs()(0)))
+            // gen_forward
+            for (k <- 0 until gens.length) {
+              gens(k).forward(dataArray.takeRight(1))
+              dataArray = dataArray :+ gens(k).getOutputs()(0)
+              // loss forward
+              loss.forward(dataArray.takeRight(1))
+              loss.backward(gradArray)
+              lossGradArray = lossGradArray :+ loss.getInputGrads()(0)
+            }
+            val grad = NDArray.zeros(data.shape, ctx)
+            for (k <- gens.length - 1 to 0 by -1) {
+              val tvGradExecutor = getTvGradExecutor(gens(k).getOutputs()(0), ctx, tvWeight)
+              tvGradExecutor.forward()
+              grad += lossGradArray(k) + tvGradExecutor.outputs(0)
+              val gNorm = NDArray.norm(grad)
+              if (gNorm.toScalar > clipNorm) grad *= clipNorm / gNorm.toScalar
+              gens(k).backward(Array(grad))
+              tvGradExecutor.dispose()
+            }
+          }
+          // Update outside the per-image scope: the updater allocates optimizer (momentum)
+          // state lazily on the first update(), which that scope would then dispose. Each
+          // update only touches its own generator's weights/grads, so deferring is equivalent.
+          gens.foreach(_.update())
+          if (idx % 20 == 0) {
+            logger.info(s"Epoch $i: Image $idx")
+            for (k <- 0 until gens.length) {
+              val n = NDArray.norm(gens(k).getInputGrads()(0))
+              logger.info(s"Data Norm : ${n.toScalar / dShape.product}")
+              n.dispose()
+            }
+          }
+          if (idx % 1000 == 0) {
+            for (k <- 0 until gens.length) {
+              gens(k).saveParams(
+                s"${saveModelPath}/$k/${modelPrefix}_" +
+                  s"${"%04d".format(i)}-${"%07d".format(idx)}.params")
+            }
           }
         }
-        data.dispose()
       }
     }
   }
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
index f7a01ba..8b2059d 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
@@ -62,56 +62,58 @@ object LstmBucketing {
 
   def runTraining(trainData : String, validationData : String,
                   ctx : Array[Context], numEpoch : Int): Unit = {
-    val batchSize = 32
-    val buckets = Array(10, 20, 30, 40, 50, 60)
-    val numHidden = 200
-    val numEmbed = 200
-    val numLstmLayer = 2
-
-    logger.info("Building vocab ...")
-    val vocab = BucketIo.defaultBuildVocab(trainData)
-
-    def BucketSymGen(key: AnyRef):
-    (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
-      val seqLen = key.asInstanceOf[Int]
-      val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
-        numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
-      (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
+    NDArrayCollector.auto().withScope {
+      val batchSize = 32
+      val buckets = Array(10, 20, 30, 40, 50, 60)
+      val numHidden = 200
+      val numEmbed = 200
+      val numLstmLayer = 2
+
+      logger.info("Building vocab ...")
+      val vocab = BucketIo.defaultBuildVocab(trainData)
+      // builds the unrolled LSTM symbol for one bucket; the bucket key is its sequence length
+      def bucketSymGen(key: AnyRef):
+      (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
+        val seqLen = key.asInstanceOf[Int]
+        val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
+          numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
+        (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
+      }
+      // initial-state inputs (c and h) for each LSTM layer, shape (batchSize, numHidden)
+      val initC = (0 until numLstmLayer).map(l =>
+        (s"l${l}_init_c_beta", (batchSize, numHidden))
+      )
+      val initH = (0 until numLstmLayer).map(l =>
+        (s"l${l}_init_h_beta", (batchSize, numHidden))
+      )
+      val initStates = initC ++ initH
+
+      val dataTrain = new BucketSentenceIter(trainData, vocab,
+        buckets, batchSize, initStates)
+      val dataVal = new BucketSentenceIter(validationData, vocab,
+        buckets, batchSize, initStates)
+
+      val model = new BucketingModule(
+        symGen = bucketSymGen,
+        defaultBucketKey = dataTrain.defaultBucketKey,
+        contexts = ctx)
+
+      val fitParams = new FitParams()
+      fitParams.setEvalMetric(
+        new CustomMetric(perplexity, name = "perplexity"))
+      fitParams.setKVStore("device")
+      fitParams.setOptimizer(
+        new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
+      fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
+      fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
+
+      logger.info("Start training ...")
+      model.fit(
+        trainData = dataTrain,
+        evalData = Some(dataVal),
+        numEpoch = numEpoch, fitParams)
+      logger.info("Finished training...")
     }
-
-    val initC = (0 until numLstmLayer).map(l =>
-      (s"l${l}_init_c_beta", (batchSize, numHidden))
-    )
-    val initH = (0 until numLstmLayer).map(l =>
-      (s"l${l}_init_h_beta", (batchSize, numHidden))
-    )
-    val initStates = initC ++ initH
-
-    val dataTrain = new BucketSentenceIter(trainData, vocab,
-      buckets, batchSize, initStates)
-    val dataVal = new BucketSentenceIter(validationData, vocab,
-      buckets, batchSize, initStates)
-
-    val model = new BucketingModule(
-      symGen = BucketSymGen,
-      defaultBucketKey = dataTrain.defaultBucketKey,
-      contexts = ctx)
-
-    val fitParams = new FitParams()
-    fitParams.setEvalMetric(
-      new CustomMetric(perplexity, name = "perplexity"))
-    fitParams.setKVStore("device")
-    fitParams.setOptimizer(
-      new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
-    fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
-    fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
-
-    logger.info("Start training ...")
-    model.fit(
-      trainData = dataTrain,
-      evalData = Some(dataVal),
-      numEpoch = numEpoch, fitParams)
-    logger.info("Finished training...")
   }
 
   def main(args: Array[String]): Unit = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
index 4786d5d..bd064db 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -30,54 +30,56 @@ object TestCharRnn {
 
   private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn])
 
-  def runTestCharRNN(dataPath: String, modelPrefix: String, starterSentence : String): Unit = {
-    // The batch size for training
-    val batchSize = 32
-    // We can support various length input
-    // For this problem, we cut each input sentence to length of 129
-    // So we only need fix length bucket
-    val buckets = List(129)
-    // hidden unit in LSTM cell
-    val numHidden = 512
-    // embedding dimension, which is, map a char to a 256 dim vector
-    val numEmbed = 256
-    // number of lstm layer
-    val numLstmLayer = 3
+  def runInferenceCharRNN(dataPath: String, modelPrefix: String, starterSentence : String): Unit = {
+    NDArrayCollector.auto().withScope {  // auto-dispose NDArrays created in this scope
+      // Training-time batch size; not referenced by the inference code below
+      val batchSize = 32
+      // Variable-length input can be handled with buckets, but for this problem
+      // each input sentence is cut to length 129, so one fixed-length bucket
+      // suffices (buckets, like batchSize, is not referenced during inference)
+      val buckets = List(129)
+      // number of hidden units in each LSTM cell
+      val numHidden = 512
+      // embedding dimension: each char is mapped to a 256-dim vector
+      val numEmbed = 256
+      // number of LSTM layers
+      val numLstmLayer = 3
 
-    // build char vocabluary from input
-    val vocab = Utils.buildVocab(dataPath)
+      // build the char vocabulary from the text at dataPath
+      val vocab = Utils.buildVocab(dataPath)
 
-    // load from check-point
-    val (_, argParams, _) = Model.loadCheckpoint(modelPrefix, 75)
+      // load the trained parameters stored in checkpoint 75 under modelPrefix
+      val (_, argParams, _) = Model.loadCheckpoint(modelPrefix, 75)
 
-    // build an inference model
-    val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
-      numHidden = numHidden, numEmbed = numEmbed,
-      numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
+      // build an inference model from the loaded parameters
+      val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
+        numHidden = numHidden, numEmbed = numEmbed,
+        numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
 
-    // generate a sequence of 1200 chars
-    val seqLength = 1200
-    val inputNdarray = NDArray.zeros(1)
-    val revertVocab = Utils.makeRevertVocab(vocab)
+      // generate a sequence of 1200 chars, feeding the model one char per step
+      val seqLength = 1200
+      val inputNdarray = NDArray.zeros(1)
+      val revertVocab = Utils.makeRevertVocab(vocab)
 
-    // Feel free to change the starter sentence
-    var output = starterSentence
-    val randomSample = true
-    var newSentence = true
-    val ignoreLength = output.length()
+      // output is seeded with starterSentence; generated chars are appended to it
+      var output = starterSentence
+      val randomSample = true
+      var newSentence = true
+      val ignoreLength = output.length()
 
-    for (i <- 0 until seqLength) {
-      if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
-      else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray)
-      val prob = model.forward(inputNdarray, newSentence)
-      newSentence = false
-      val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
-      if (nextChar == "") newSentence = true
-      if (i >= ignoreLength) output = output ++ nextChar
-    }
+      for (i <- 0 until seqLength) {
+        if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
+        else Utils.makeInput(output.takeRight(1)(0), vocab, inputNdarray)
+        val prob = model.forward(inputNdarray, newSentence)
+        newSentence = false
+        val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
+        if (nextChar == "") newSentence = true  // an empty char starts a new sentence
+        if (i >= ignoreLength) output = output ++ nextChar  // only append past the seed
+      }
 
-    // Let's see what we can learned from char in Obama's speech.
-    logger.info(output)
+      // Let's see what the model learned from the chars in Obama's speech.
+      logger.info(output)
+    }
   }
 
   def main(args: Array[String]): Unit = {
@@ -86,7 +88,7 @@ object TestCharRnn {
     try {
       parser.parseArgument(args.toList.asJava)
       assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null)
-      runTestCharRNN(stcr.dataPath, stcr.modelPrefix, stcr.starterSentence)
+      runInferenceCharRNN(stcr.dataPath, stcr.modelPrefix, stcr.starterSentence)
     } catch {
       case ex: Exception => {
         logger.error(ex.getMessage, ex)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
index fb59705..c90b763 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
@@ -33,125 +33,127 @@ object TrainCharRnn {
 
   def runTrainCharRnn(dataPath: String, saveModelPath: String,
                       ctx : Context, numEpoch : Int): Unit = {
-    // The batch size for training
-    val batchSize = 32
-    // We can support various length input
-    // For this problem, we cut each input sentence to length of 129
-    // So we only need fix length bucket
-    val buckets = Array(129)
-    // hidden unit in LSTM cell
-    val numHidden = 512
-    // embedding dimension, which is, map a char to a 256 dim vector
-    val numEmbed = 256
-    // number of lstm layer
-    val numLstmLayer = 3
-    // we will show a quick demo in 2 epoch
-    // learning rate
-    val learningRate = 0.001f
-    // we will use pure sgd without momentum
-    val momentum = 0.0f
-
-    val vocab = Utils.buildVocab(dataPath)
-
-    // generate symbol for a length
-    def symGen(seqLen: Int): Symbol = {
-      Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
-        numHidden = numHidden, numEmbed = numEmbed,
-        numLabel = vocab.size + 1, dropout = 0.2f)
-    }
+    NDArrayCollector.auto().withScope {  // auto-dispose NDArrays created in this scope
+      // The batch size for training
+      val batchSize = 32
+      // Variable-length input can be handled with buckets, but for this problem
+      // each input sentence is cut to length 129, so one fixed-length bucket
+      // is all we need
+      val buckets = Array(129)
+      // number of hidden units in each LSTM cell
+      val numHidden = 512
+      // embedding dimension: each char is mapped to a 256-dim vector
+      val numEmbed = 256
+      // number of LSTM layers
+      val numLstmLayer = 3
+      // the number of training epochs comes from the numEpoch argument
+      // learning rate for the Adam optimizer created below
+      val learningRate = 0.001f
+      // momentum is not used: the optimizer below is Adam, not plain SGD
+      val momentum = 0.0f
+
+      val vocab = Utils.buildVocab(dataPath)
+
+      // build the unrolled LSTM symbol for a given sequence length
+      def symGen(seqLen: Int): Symbol = {
+        Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size + 1,
+          numHidden = numHidden, numEmbed = numEmbed,
+          numLabel = vocab.size + 1, dropout = 0.2f)
+      }
 
-    // initalize states for LSTM
-    val initC = for (l <- 0 until numLstmLayer)
-      yield (s"l${l}_init_c_beta", (batchSize, numHidden))
-    val initH = for (l <- 0 until numLstmLayer)
-      yield (s"l${l}_init_h_beta", (batchSize, numHidden))
-    val initStates = initC ++ initH
+      // names and shapes of the initial c/h states of every LSTM layer
+      val initC = for (l <- 0 until numLstmLayer)
+        yield (s"l${l}_init_c_beta", (batchSize, numHidden))
+      val initH = for (l <- 0 until numLstmLayer)
+        yield (s"l${l}_init_h_beta", (batchSize, numHidden))
+      val initStates = initC ++ initH
 
-    val dataTrain = new BucketIo.BucketSentenceIter(dataPath, vocab, buckets,
-      batchSize, initStates, seperateChar = "\n",
-      text2Id = Utils.text2Id, readContent = Utils.readContent)
+      val dataTrain = new BucketIo.BucketSentenceIter(dataPath, vocab, buckets,
+        batchSize, initStates, seperateChar = "\n",
+        text2Id = Utils.text2Id, readContent = Utils.readContent)
 
-    // the network symbol
-    val symbol = symGen(buckets(0))
+      // the network symbol, unrolled to the single bucket length
+      val symbol = symGen(buckets(0))
 
-    val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
-    val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
+      val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
+      val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
 
-    val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
+      val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
 
-    val argNames = symbol.listArguments()
-    val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
-    val auxNames = symbol.listAuxiliaryStates()
-    val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
+      val argNames = symbol.listArguments()
+      val argDict = argNames.zip(argShapes.map(NDArray.zeros(_, ctx))).toMap
+      val auxNames = symbol.listAuxiliaryStates()
+      val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
 
-    val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
-      !datasAndLabels.contains(name)
-    }.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
+      val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
+        !datasAndLabels.contains(name)
+      }.map(x => x._1 -> NDArray.empty(x._2, ctx)).toMap
 
-    argDict.foreach { case (name, ndArray) =>
-      if (!datasAndLabels.contains(name)) {
-        initializer.initWeight(name, ndArray)
+      argDict.foreach { case (name, ndArray) =>
+        if (!datasAndLabels.contains(name)) {
+          initializer.initWeight(name, ndArray)
+        }
       }
-    }
 
-    val data = argDict("data")
-    val label = argDict("softmax_label")
+      val data = argDict("data")
+      val label = argDict("softmax_label")
 
-    val executor = symbol.bind(ctx, argDict, gradDict)
+      val executor = symbol.bind(ctx, argDict, gradDict)
 
-    val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
+      val opt = new Adam(learningRate = learningRate, wd = 0.0001f)
 
-    val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
-      (idx, name, grad, opt.createState(idx, argDict(name)))
-    }
+      val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
+        (idx, name, grad, opt.createState(idx, argDict(name)))
+      }
 
-    val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
-    val batchEndCallback = new Callback.Speedometer(batchSize, 50)
-    val epochEndCallback = Utils.doCheckpoint(s"${saveModelPath}/obama")
-
-    for (epoch <- 0 until numEpoch) {
-      // Training phase
-      val tic = System.currentTimeMillis
-      evalMetric.reset()
-      var nBatch = 0
-      var epochDone = false
-      // Iterate over training data.
-      dataTrain.reset()
-      while (!epochDone) {
-        var doReset = true
-        while (doReset && dataTrain.hasNext) {
-          val dataBatch = dataTrain.next()
-
-          data.set(dataBatch.data(0))
-          label.set(dataBatch.label(0))
-          executor.forward(isTrain = true)
-          executor.backward()
-          paramsGrads.foreach { case (idx, name, grad, optimState) =>
-            opt.update(idx, argDict(name), grad, optimState)
+      val evalMetric = new CustomMetric(Utils.perplexity, "perplexity")
+      val batchEndCallback = new Callback.Speedometer(batchSize, 50)
+      val epochEndCallback = Utils.doCheckpoint(s"${saveModelPath}/obama")
+
+      for (epoch <- 0 until numEpoch) {
+        // Training phase
+        val tic = System.currentTimeMillis
+        evalMetric.reset()
+        var nBatch = 0
+        var epochDone = false
+        // Iterate over training data.
+        dataTrain.reset()
+        while (!epochDone) {
+          var doReset = true
+          while (doReset && dataTrain.hasNext) {
+            val dataBatch = dataTrain.next()
+
+            data.set(dataBatch.data(0))
+            label.set(dataBatch.label(0))
+            executor.forward(isTrain = true)
+            executor.backward()
+            paramsGrads.foreach { case (idx, name, grad, optimState) =>
+              opt.update(idx, argDict(name), grad, optimState)
+            }
+
+            // evaluate after the update; NOTE(review): out_cpu_array looks stale here
+            evalMetric.update(dataBatch.label, executor.outputs)
+
+            nBatch += 1
+            batchEndCallback.invoke(epoch, nBatch, evalMetric)
           }
-
-          // evaluate at end, so out_cpu_array can lazy copy
-          evalMetric.update(dataBatch.label, executor.outputs)
-
-          nBatch += 1
-          batchEndCallback.invoke(epoch, nBatch, evalMetric)
+          if (doReset) {
+            dataTrain.reset()
+          }
+          // epoch done; doReset is never cleared, so this loop body runs once
+          epochDone = true
         }
-        if (doReset) {
-          dataTrain.reset()
+        val (name, value) = evalMetric.get
+        name.zip(value).foreach { case (n, v) =>
+          logger.info(s"Epoch[$epoch] Train-$n=$v")
         }
-        // this epoch is done
-        epochDone = true
-      }
-      val (name, value) = evalMetric.get
-      name.zip(value).foreach { case (n, v) =>
-        logger.info(s"Epoch[$epoch] Train-$n=$v")
-      }
-      val toc = System.currentTimeMillis
-      logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
+        val toc = System.currentTimeMillis
+        logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
 
-      epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+        epochEndCallback.invoke(epoch, symbol, argDict, auxDict)
+      }
+      executor.dispose()  // executors are not NDArrays, so free this one explicitly
     }
-    executor.dispose()
   }
 
   def main(args: Array[String]): Unit = {
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
index 95c9823..44025c0 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
@@ -21,7 +21,7 @@ import java.io.File
 import java.net.URL
 
 import org.apache.commons.io.FileUtils
-import org.apache.mxnet.Context
+import org.apache.mxnet.{Context, NDArrayCollector}
 import org.apache.mxnetexamples.Util
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.slf4j.LoggerFactory
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
index 96820ce..59faba9 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
@@ -18,14 +18,14 @@
 package org.apache.mxnetexamples.gan
 
 import java.io.File
-import org.apache.mxnet.Context
+
+import org.apache.mxnet.{Context, NDArrayCollector}
 import org.apache.mxnetexamples.Util
 import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
 import org.slf4j.LoggerFactory
 
 import scala.sys.process.Process
 
-@Ignore
 class GanExampleSuite extends FunSuite with BeforeAndAfterAll{
   private val logger = LoggerFactory.getLogger(classOf[GanExampleSuite])
 
@@ -44,7 +44,8 @@ class GanExampleSuite extends FunSuite with BeforeAndAfterAll{
 
         val context = Context.gpu()
 
-        val output = GanMnist.runTraining(modelDirPath, context, modelDirPath, 5)
+        val output = GanMnist.runTraining(modelDirPath, context, modelDirPath, 3)
+
         Process("rm -rf " + modelDirPath) !
 
         assert(output >= 0.0f)
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
index f0bb07b..34d3bc9 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
@@ -23,7 +23,7 @@ import java.io.File
 import java.net.URL
 
 import org.apache.commons.io.FileUtils
-import org.apache.mxnet.Context
+import org.apache.mxnet.{Context, NDArrayCollector}
 import org.apache.mxnetexamples.Util
 
 import sys.process.Process
@@ -64,10 +64,10 @@ class ImageClassifierExampleSuite extends FunSuite with BeforeAndAfterAll {
     }
 
     val output = ImageClassifierExample.runInferenceOnSingleImage(modelDirPath + "resnet-18",
-      inputImagePath, context)
+        inputImagePath, context)
 
     val outputList = ImageClassifierExample.runInferenceOnBatchOfImage(modelDirPath + "resnet-18",
-      inputImageDir, context)
+        inputImageDir, context)
 
     Process("rm -rf " + modelDirPath + " " + inputImageDir) !
 
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
index 31da385..addc837 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
@@ -21,7 +21,7 @@ import java.io.File
 import java.net.URL
 
 import org.apache.commons.io.FileUtils
-import org.apache.mxnet.Context
+import org.apache.mxnet.{Context, NDArrayCollector}
 import org.apache.mxnetexamples.Util
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.slf4j.LoggerFactory
@@ -61,11 +61,11 @@ class ObjectDetectorExampleSuite extends FunSuite with BeforeAndAfterAll {
     }
 
     val output = SSDClassifierExample.runObjectDetectionSingle(modelDirPath + "resnet50_ssd_model",
-      inputImagePath, context)
+        inputImagePath, context)
 
     val outputList = SSDClassifierExample.runObjectDetectionBatch(
-      modelDirPath + "resnet50_ssd_model",
-      inputImageDir, context)
+        modelDirPath + "resnet50_ssd_model",
+        inputImageDir, context)
 
     Process("rm -rf " + modelDirPath + " " + inputImageDir) !
 
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
index b86f675..983978d 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala
@@ -17,26 +17,11 @@
 
 package org.apache.mxnetexamples.multitask
 
-import org.apache.commons.io.FileUtils
-import org.apache.mxnet.Context
-import org.scalatest.FunSuite
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.DataIter
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Shape
-import org.apache.mxnet.EvalMetric
 import org.apache.mxnet.Context
-import org.apache.mxnet.Xavier
-import org.apache.mxnet.optimizer.RMSProp
-import java.io.File
-import java.net.URL
 
-import scala.sys.process.Process
-import scala.collection.immutable.ListMap
-import scala.collection.immutable.IndexedSeq
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import org.scalatest.FunSuite
 
 
 /**
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
index dc8fc5b..71c2b35 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.mxnetexamples.neuralstyle
 
-import org.apache.mxnet.Context
+import org.apache.mxnet.{Context, NDArrayCollector}
 import org.apache.mxnetexamples.Util
 import org.apache.mxnetexamples.neuralstyle.end2end.{BoostInference, BoostTrain}
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
index b393a43..14fb7b8 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
@@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory
 
 import scala.sys.process.Process
 
-@Ignore
 class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll {
   private val logger = LoggerFactory.getLogger(classOf[ExampleRNNSuite])
 
@@ -51,7 +50,7 @@ class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll {
       ctx = Context.gpu()
     }
     LstmBucketing.runTraining(tempDirPath + "/RNN/sherlockholmes.train.txt",
-        tempDirPath + "/RNN/sherlockholmes.valid.txt", Array(ctx), 1)
+      tempDirPath + "/RNN/sherlockholmes.valid.txt", Array(ctx), 1)
   }
 
   test("Example CI: Test TrainCharRNN") {
@@ -60,16 +59,16 @@ class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll {
       System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
       val ctx = Context.gpu()
       TrainCharRnn.runTrainCharRnn(tempDirPath + "/RNN/obama.txt",
-          tempDirPath, ctx, 1)
+        tempDirPath, ctx, 1)
     } else {
       logger.info("CPU not supported for this test, skipped...")
     }
   }
 
-  test("Example CI: Test TestCharRNN") {
+  test("Example CI: Test Inference on CharRNN") {
     val tempDirPath = System.getProperty("java.io.tmpdir")
     val ctx = Context.gpu()
-    TestCharRnn.runTestCharRNN(tempDirPath + "/RNN/obama.txt",
-        tempDirPath + "/RNN/obama", "The joke")
+    TestCharRnn.runInferenceCharRNN(tempDirPath + "/RNN/obama.txt",
+      tempDirPath + "/RNN/obama", "The joke")  // ctx above is unused by this call
   }
 }