Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/23 23:49:53 UTC

[incubator-mxnet] branch master updated: use ResourceScope in Model/Trainer/FeedForward.scala (#12882)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b4df85  use ResourceScope in Model/Trainer/FeedForward.scala (#12882)
6b4df85 is described below

commit 6b4df8576e373ff68b4fcd99ae6318ddb4b9ed12
Author: Naveen Swamy <mn...@gmail.com>
AuthorDate: Tue Oct 23 16:49:37 2018 -0700

    use ResourceScope in Model/Trainer/FeedForward.scala (#12882)
    
    * use ResourceScope in Model/Trainer/FeedForward.scala
    
    * add a public moveToOuterScope method that moves a resource to the outer scope if one exists
    
    * fix a memory leak in FeedForward.scala by making it a NativeResource and disposing argParams and auxParams
    in the dispose() method (a usage sketch of ResourceScope follows the diffstat below)
---
 .../main/scala/org/apache/mxnet/FeedForward.scala  | 152 +++++++++++++--------
 .../scala/org/apache/mxnet/NativeResource.scala    |   8 +-
 .../scala/org/apache/mxnet/ResourceScope.scala     |  35 +++--
 .../imclassification/TrainModel.scala              |  80 +++++------
 .../imclassification/util/Trainer.scala            | 133 +++++++++---------
 5 files changed, 230 insertions(+), 178 deletions(-)
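
For readers new to ResourceScope, here is a minimal, illustrative sketch (not part of this commit) of the usage pattern this change moves Model/Trainer/FeedForward to: NDArrays allocated inside a ResourceScope.using block are disposed when the block exits, and a NativeResource returned from the block is moved to the enclosing scope instead of being freed. The shapes and values below are invented for the example.

    import org.apache.mxnet.{NDArray, ResourceScope, Shape}

    val sum: NDArray = ResourceScope.using() {
      val a = NDArray.ones(Shape(2, 2))   // owned by this scope, disposed on exit
      val b = NDArray.ones(Shape(2, 2))   // owned by this scope, disposed on exit
      a + b                               // the returned NDArray is moved to the outer scope
    }
    // `sum` remains valid here; `a` and `b` have been released.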

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
index 00a1450..2ed9d8c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -17,9 +17,10 @@
 
 package org.apache.mxnet
 
+import org.apache.mxnet.Base.CPtrAddress
 import org.apache.mxnet.io.NDArrayIter
 import org.apache.mxnet.optimizer.SGD
-import org.slf4j.{LoggerFactory, Logger}
+import org.slf4j.{Logger, LoggerFactory}
 
 import scala.collection.mutable.ListBuffer
 
@@ -55,7 +56,7 @@ class FeedForward private(
     argParams: Map[String, NDArray],
     auxParams: Map[String, NDArray],
     private val allowExtraParams: Boolean,
-    val beginEpoch: Int) {
+    val beginEpoch: Int) extends NativeResource {
 
   val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
   private var argumentChecked = false
@@ -126,6 +127,8 @@ class FeedForward private(
   }
 
   // Initialize weight parameters and auxiliary states
+  // The NDArrays associated with _argParams and _auxParams are not disposed here; instead
+  // they are passed to an outer scope if one is available.
   private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
   : (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = {
     val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
@@ -137,16 +140,26 @@ class FeedForward private(
     val paramNameShapes = (argNames zip argShapes).filter { case (name, _) =>
       paramNames.contains(name)
     }
-    val argParams = paramNameShapes.map { case (name, shape) =>
-      (name, NDArray.zeros(shape))
+    val argParams = paramNameShapes.map { case (name, shape) => {
+        val param = NDArray.zeros(shape)
+        val curScope = ResourceScope.getCurrentScope()
+        if (curScope.isDefined) curScope.get.moveToOuterScope(param)
+        (name, param)
+      }
     }.toMap
-    val auxParams = (auxNames zip auxShapes).map { case (name, shape) =>
-      (name, NDArray.zeros(shape))
+
+    val auxParams = (auxNames zip auxShapes).map { case (name, shape) => {
+        val param = NDArray.zeros(shape)
+        val curScope = ResourceScope.getCurrentScope()
+        if (curScope.isDefined) curScope.get.moveToOuterScope(param)
+        (name, param)
+      }
     }.toMap
 
     for ((k, v) <- argParams) {
       if (_argParams != null && _argParams.contains(k) && (!overwrite)) {
         argParams(k).set(_argParams(k))
+
       } else {
         initializer(k, v)
       }
@@ -277,13 +290,15 @@ class FeedForward private(
   def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric, kvStoreType: String,
           epochEndCallback: EpochEndCallback, batchEndCallback: BatchEndCallback,
           logger: Logger, workLoadList: Seq[Float]): Unit = {
-    // init params first to allow kv store use _argParams to decide its type
-    initSymbolParams(trainData)
-    // create kvstore
-    val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
-    fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
-      epochEndCallback, batchEndCallback, logger, workLoadList)
-    kvStore.foreach(_.dispose())
+    ResourceScope.using() {
+      // init params first to allow kv store use _argParams to decide its type
+      initSymbolParams(trainData)
+      // create kvstore
+      val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
+      fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
+        epochEndCallback, batchEndCallback, logger, workLoadList)
+//      kvStore.foreach(_.dispose())
+    }
   }
 
   def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -313,11 +328,13 @@ class FeedForward private(
           batchEndCallback: BatchEndCallback, logger: Logger,
           workLoadList: Seq[Float]): Unit = {
     // init params first to allow kv store use _argParams to decide its type
-    initSymbolParams(trainData)
-    // create kvstore
-    val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
-    fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
-      epochEndCallback, batchEndCallback, logger, workLoadList)
+    ResourceScope.using() {
+      initSymbolParams(trainData)
+      // create kvstore
+      val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
+      fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
+        epochEndCallback, batchEndCallback, logger, workLoadList)
+    }
   }
 
   def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -352,44 +369,49 @@ class FeedForward private(
                   batchEndCallback: BatchEndCallback = null, logger: Logger = FeedForward.logger,
                   workLoadList: Seq[Float] = null): Unit = {
     require(evalMetric != null, "evalMetric cannot be null")
-    val (argNames, paramNames, auxNames) = initSymbolParams(trainData)
-
-    // init optimizer
-    val batchSizeMultiplier = kvStore.map { kv =>
-      if (kv.`type` == "dist_sync") {
-        kv.numWorkers
-      } else {
-        1
-      }
-    }
-    val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
-    this.optimizer.setArgNames(argNames)
-    this.optimizer.setRescaleGrad(1f / batchSize)
-    this.optimizer.setSymbol(this.symbol)
-    val paramIdx2Name =
-      if (updateOnKVStore) {
-        paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
-      } else {
-        paramNames.zipWithIndex.flatMap { case (name, idx) =>
-          (0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
-        }.toMap
+    // TODO: https://issues.apache.org/jira/browse/MXNET-1171
+    // This leaks memory: initSymbolParams -> initParams has already been called and has
+    // allocated NDArrays in argParams and auxParams, and calling it again here overwrites them.
+    // The PhantomReference should release the overwritten arrays once GC runs, but we have
+    // to wait for that GC to happen.
+      val (argNames, paramNames, auxNames) = initSymbolParams(trainData)
+
+      // init optimizer
+      val batchSizeMultiplier = kvStore.map { kv =>
+        if (kv.`type` == "dist_sync") {
+          kv.numWorkers
+        } else {
+          1
+        }
       }
-    this.optimizer.setIdx2Name(paramIdx2Name)
-
-    logger.debug("Start training on multi-device")
-    Model.trainMultiDevice(
-      symbol, ctx, argNames, paramNames, auxNames,
-      _argParams, _auxParams,
-      this.beginEpoch, this.numEpoch,
-      this.epochSize, this.optimizer,
-      kvStore, updateOnKVStore,
-      trainData = trainData, evalData = Option(evalData),
-      evalMetric = evalMetric,
-      epochEndCallback = Option(epochEndCallback),
-      batchEndCallback = Option(batchEndCallback),
-      workLoadList = workLoadList,
-      monitor = monitor,
-      symGen = symGen)
+      val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
+      this.optimizer.setArgNames(argNames)
+      this.optimizer.setRescaleGrad(1f / batchSize)
+      this.optimizer.setSymbol(this.symbol)
+      val paramIdx2Name =
+        if (updateOnKVStore) {
+          paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
+        } else {
+          paramNames.zipWithIndex.flatMap { case (name, idx) =>
+            (0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
+          }.toMap
+        }
+      this.optimizer.setIdx2Name(paramIdx2Name)
+
+      logger.debug("Start training on multi-device")
+      Model.trainMultiDevice(
+        symbol, ctx, argNames, paramNames, auxNames,
+        _argParams, _auxParams,
+        this.beginEpoch, this.numEpoch,
+        this.epochSize, this.optimizer,
+        kvStore, updateOnKVStore,
+        trainData = trainData, evalData = Option(evalData),
+        evalMetric = evalMetric,
+        epochEndCallback = Option(epochEndCallback),
+        batchEndCallback = Option(batchEndCallback),
+        workLoadList = workLoadList,
+        monitor = monitor,
+        symGen = symGen)
   }
 
   /**
@@ -416,9 +438,29 @@ class FeedForward private(
   def serialize(): Array[Byte] = {
     Model.serialize(this.symbol, getArgParams, getAuxParams)
   }
+
+  // hack to make FeedForward work with ResourceScope and
+  // automatically release _argParams and _auxParams
+  override def nativeAddress: CPtrAddress = hashCode()
+
+  override def nativeDeAllocator: CPtrAddress => Int = FeedForward.doNothingDeAllocator
+
+  override val ref: NativeResourceRef = super.register()
+
+  override val bytesAllocated: Long = 0L
+
+  override def dispose(): Unit = {
+    if (!super.isDisposed) {
+      _argParams.foreach { case (_, param) => param.dispose() }
+      _auxParams.foreach { case (_, param) => param.dispose() }
+    }
+  }
 }
 
 object FeedForward {
+
+  private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0
+
   private val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
   // Check if name is a data argument.
   private def isDataArg(name: String): Boolean = {
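
As context for the initParams change above, the following is a hedged sketch of the allocate-then-promote pattern it uses. The helper name allocParamInOuterScope is invented for illustration and does not exist in the codebase.

    import org.apache.mxnet.{NDArray, ResourceScope, Shape}

    // Hypothetical helper mirroring the pattern in initParams: allocate a parameter,
    // then hand it to the enclosing scope so that closing the current scope does not free it.
    def allocParamInOuterScope(shape: Shape): NDArray = {
      val param = NDArray.zeros(shape)
      val curScope = ResourceScope.getCurrentScope()
      if (curScope.isDefined) curScope.get.moveToOuterScope(param)
      param
    }
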
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
index 48d4b0c..1806b86 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
@@ -46,7 +46,8 @@ private[mxnet] trait NativeResource
     */
   def nativeDeAllocator: (CPtrAddress => Int)
 
-  /** Call NativeResource.register to get the reference
+  /**
+    * Call NativeResource.register to get the reference
     */
   val ref: NativeResourceRef
 
@@ -56,6 +57,7 @@ private[mxnet] trait NativeResource
   // intentionally making it a val, so it gets evaluated when defined
   val bytesAllocated: Long
 
+  // this is set and unset by [[ResourceScope.add]] and [[ResourceScope.remove]]
   private[mxnet] var scope: Option[ResourceScope] = None
 
   @volatile private var disposed = false
@@ -69,11 +71,11 @@ private[mxnet] trait NativeResource
     *         using PhantomReference
     */
   def register(): NativeResourceRef = {
-    scope = ResourceScope.getCurrentScope()
+    val scope = ResourceScope.getCurrentScope()
     if (scope.isDefined) scope.get.add(this)
 
     NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
-    // register with PhantomRef tracking to release incase the objects go
+    // register with PhantomRef tracking to release in case the objects go
     // out of reference within scope but are held for long time
     NativeResourceRef.register(this, nativeDeAllocator)
  }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
index 1c5782d..30fe147 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -58,6 +58,7 @@ class ResourceScope extends AutoCloseable {
     */
   def add(resource: NativeResource): Unit = {
     resourceQ.+=(resource)
+    resource.scope = Some(this)
   }
 
   /**
@@ -67,7 +68,21 @@ class ResourceScope extends AutoCloseable {
     */
   def remove(resource: NativeResource): Unit = {
     resourceQ.-=(resource)
+    resource.scope = None
   }
+
+  /**
+    * Removes the resource from the current scope and moves it to the outer scope if one exists
+    * @param resource Resource to be moved to an outer scope
+    */
+  def moveToOuterScope(resource: NativeResource): Unit = {
+    val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
+    if (prevScope.isDefined) {
+      this.remove(resource)
+      prevScope.get.add(resource)
+    } else this.remove(resource)
+  }
+
 }
 
 object ResourceScope {
@@ -92,32 +107,22 @@ object ResourceScope {
 
     val curScope = if (scope != null) scope else new ResourceScope()
 
-    val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
-
     @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
       g.foreach( n =>
         n match {
           case nRes: NativeResource => {
-            removeAndAddToPrevScope(nRes)
+            curScope.moveToOuterScope(nRes)
           }
           case kv: scala.Tuple2[_, _] => {
-            if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+            if (kv._1.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
               kv._1.asInstanceOf[NativeResource])
-            if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+            if (kv._2.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
               kv._2.asInstanceOf[NativeResource])
           }
         }
       )
     }
 
-    @inline def removeAndAddToPrevScope(r: NativeResource) = {
-      curScope.remove(r)
-      if (prevScope.isDefined)  {
-        prevScope.get.add(r)
-        r.scope = prevScope
-      }
-    }
-
     @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
       if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
     }
@@ -129,8 +134,8 @@ object ResourceScope {
        ret match {
           // don't de-allocate if returning any collection that contains NativeResource.
         case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric)
-        case nRes: NativeResource => removeAndAddToPrevScope(nRes)
-        case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) )
+        case nRes: NativeResource => curScope.moveToOuterScope(nRes)
+        case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) )
         case _ => // do nothing
       }
       ret
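
To make the nesting behaviour of moveToOuterScope concrete, here is a small illustrative sketch, not taken from the commit, with two nested scopes and arbitrary shapes.

    import org.apache.mxnet.{NDArray, ResourceScope, Shape}

    var survivor: NDArray = null
    ResourceScope.using() {                       // outer scope
      ResourceScope.using() {                     // inner scope
        val tmp = NDArray.ones(Shape(3))          // freed when the inner scope closes
        survivor = NDArray.zeros(Shape(3))
        // Promote `survivor` one level out so the inner scope's close leaves it alive.
        ResourceScope.getCurrentScope().foreach(_.moveToOuterScope(survivor))
      }
      println(survivor.shape)                     // still valid; freed when the outer scope closes
    }
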
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index 608e191..f6c283c 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -43,7 +43,7 @@ object TrainModel {
     */
   def test(model: String, dataPath: String, numExamples: Int = 60000,
            numEpochs: Int = 10, benchmark: Boolean = false): Float = {
-    NDArrayCollector.auto().withScope {
+    ResourceScope.using() {
       val devs = Array(Context.cpu(0))
       val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
       val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
@@ -110,44 +110,46 @@ object TrainModel {
     val inst = new TrainModel
     val parser: CmdLineParser = new CmdLineParser(inst)
     try {
-      parser.parseArgument(args.toList.asJava)
-
-      val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
-      else inst.dataDir
-
-      val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath,
-        inst.numLayers, inst.numExamples, inst.benchmark)
-
-      val devs =
-        if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
-        else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
-        else Array(Context.cpu(0))
-
-      val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
-      envs.put("DMLC_ROLE", inst.role)
-      if (inst.schedulerHost != null) {
-        require(inst.schedulerPort > 0, "scheduler port not specified")
-        envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost)
-        envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString)
-        require(inst.numWorker > 0, "Num of workers must > 0")
-        envs.put("DMLC_NUM_WORKER", inst.numWorker.toString)
-        require(inst.numServer > 0, "Num of servers must > 0")
-        envs.put("DMLC_NUM_SERVER", inst.numServer.toString)
-        logger.info("Init PS environments")
-        KVStoreServer.init(envs.toMap)
-      }
-
-      if (inst.role != "worker") {
-        logger.info("Start KVStoreServer for scheduler & servers")
-        KVStoreServer.start()
-      } else {
-        Trainer.fit(batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs,
-          network = net, dataLoader = dataLoader,
-          kvStore = inst.kvStore, numEpochs = inst.numEpochs,
-          modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
-          lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch,
-          monitorSize = inst.monitor)
-        logger.info("Finish fit ...")
+      ResourceScope.using() {
+        parser.parseArgument(args.toList.asJava)
+
+        val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
+        else inst.dataDir
+
+        val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath,
+          inst.numLayers, inst.numExamples, inst.benchmark)
+
+        val devs =
+          if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
+          else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
+          else Array(Context.cpu(0))
+
+        val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
+        envs.put("DMLC_ROLE", inst.role)
+        if (inst.schedulerHost != null) {
+          require(inst.schedulerPort > 0, "scheduler port not specified")
+          envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost)
+          envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString)
+          require(inst.numWorker > 0, "Num of workers must > 0")
+          envs.put("DMLC_NUM_WORKER", inst.numWorker.toString)
+          require(inst.numServer > 0, "Num of servers must > 0")
+          envs.put("DMLC_NUM_SERVER", inst.numServer.toString)
+          logger.info("Init PS environments")
+          KVStoreServer.init(envs.toMap)
+        }
+
+        if (inst.role != "worker") {
+          logger.info("Start KVStoreServer for scheduler & servers")
+          KVStoreServer.start()
+        } else {
+          Trainer.fit(batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs,
+            network = net, dataLoader = dataLoader,
+            kvStore = inst.kvStore, numEpochs = inst.numEpochs,
+            modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
+            lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch,
+            monitorSize = inst.monitor)
+          logger.info("Finish fit ...")
+        }
       }
     } catch {
       case ex: Exception => {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
index 9a54e58..276816c 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
@@ -50,83 +50,84 @@ object Trainer {
           lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
           clipGradient: Float = 0f, monitorSize: Int = -1): Accuracy = {
     // kvstore
-    var kv = KVStore.create(kvStore)
+    ResourceScope.using()  {
+      var kv = KVStore.create(kvStore)
 
-    // load model
-    val modelPrefixWithRank =
-      if (modelPrefix == null) null
-      else modelPrefix + s"-${kv.rank}"
+      // load model
+      val modelPrefixWithRank =
+        if (modelPrefix == null) null
+        else modelPrefix + s"-${kv.rank}"
 
-    val (argParams, auxParams, beginEpoch) =
-      if (loadEpoch >= 0) {
-        require(modelPrefixWithRank != null)
-        val tmp = FeedForward.load(modelPrefix, loadEpoch)
-        (tmp.getArgParams, tmp.getAuxParams, loadEpoch)
-      } else {
-        (null, null, 0)
-      }
+      val (argParams, auxParams, beginEpoch) =
+        if (loadEpoch >= 0) {
+          require(modelPrefixWithRank != null)
+          val tmp = FeedForward.load(modelPrefix, loadEpoch)
+          (tmp.getArgParams, tmp.getAuxParams, loadEpoch)
+        } else {
+          (null, null, 0)
+        }
 
-    // save model
-    val checkpoint: EpochEndCallback =
-      if (modelPrefix == null) null
-      else new EpochEndCallback {
-        override def invoke(epoch: Int, symbol: Symbol,
-                            argParams: Map[String, NDArray],
-                            auxStates: Map[String, NDArray]): Unit = {
-          Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams, auxParams)
+      // save model
+      val checkpoint: EpochEndCallback =
+        if (modelPrefix == null) null
+        else new EpochEndCallback {
+          override def invoke(epoch: Int, symbol: Symbol,
+                              argParams: Map[String, NDArray],
+                              auxStates: Map[String, NDArray]): Unit = {
+            Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams, auxParams)
+          }
         }
-      }
 
-    // data
-    val (train, validation) = dataLoader(batchSize, kv)
+      // data
+      val (train, validation) = dataLoader(batchSize, kv)
 
-    // train
-    val epochSize =
-      if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers
-      else numExamples / batchSize
+      // train
+      val epochSize =
+        if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers
+        else numExamples / batchSize
 
-    val lrScheduler =
-      if (lrFactor < 1f) {
-        new FactorScheduler(step = Math.max((epochSize * lrFactorEpoch).toInt, 1),
-                            factor = lrFactor)
-      } else {
-        null
-      }
-    val optimizer: Optimizer = new SGD(learningRate = lr,
-      lrScheduler = lrScheduler, clipGradient = clipGradient,
-      momentum = 0.9f, wd = 0.00001f)
+      val lrScheduler =
+        if (lrFactor < 1f) {
+          new FactorScheduler(step = Math.max((epochSize * lrFactorEpoch).toInt, 1),
+                              factor = lrFactor)
+        } else {
+          null
+        }
+      val optimizer: Optimizer = new SGD(learningRate = lr,
+        lrScheduler = lrScheduler, clipGradient = clipGradient,
+        momentum = 0.9f, wd = 0.00001f)
 
-    // disable kvstore for single device
-    if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType != "gpu")) {
-      kv.dispose()
-      kv = null
-    }
+      // disable kvstore for single device
+      if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType != "gpu")) {
+        kv.dispose()
+        kv = null
+      }
 
-    val model = new FeedForward(ctx = devs,
-                                symbol = network,
-                                numEpoch = numEpochs,
-                                optimizer = optimizer,
-                                initializer = new Xavier(factorType = "in", magnitude = 2.34f),
-                                argParams = argParams,
-                                auxParams = auxParams,
-                                beginEpoch = beginEpoch,
-                                epochSize = epochSize)
-    if (monitorSize > 0) {
-      model.setMonitor(new Monitor(monitorSize))
-    }
-    val acc = new Accuracy()
-    model.fit(trainData = train,
-              evalData = validation,
-              evalMetric = acc,
-              kvStore = kv,
-              batchEndCallback = new Speedometer(batchSize, 50),
-              epochEndCallback = checkpoint)
-    if (kv != null) {
-      kv.dispose()
+      val model = new FeedForward(ctx = devs,
+                                  symbol = network,
+                                  numEpoch = numEpochs,
+                                  optimizer = optimizer,
+                                  initializer = new Xavier(factorType = "in", magnitude = 2.34f),
+                                  argParams = argParams,
+                                  auxParams = auxParams,
+                                  beginEpoch = beginEpoch,
+                                  epochSize = epochSize)
+      if (monitorSize > 0) {
+        model.setMonitor(new Monitor(monitorSize))
+      }
+      val acc = new Accuracy()
+      model.fit(trainData = train,
+                evalData = validation,
+                evalMetric = acc,
+                kvStore = kv,
+                batchEndCallback = new Speedometer(batchSize, 50),
+                epochEndCallback = checkpoint)
+      if (kv != null) {
+        kv.dispose()
+      }
+      acc
     }
-    acc
   }
-
   // scalastyle:on parameterNum
 }