You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/19 20:59:01 UTC

[GitHub] nswamy closed pull request #12647: NativeResource Management in Scala

nswamy closed pull request #12647: NativeResource Management in Scala
URL: https://github.com/apache/incubator-mxnet/pull/12647
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it once the request is merged):

diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index ea3a2d68c9f..e93169f08fa 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -123,5 +123,12 @@
       <artifactId>commons-io</artifactId>
       <version>2.1</version>
     </dependency>
+    <!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <version>1.10.19</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 </project>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index fc791d5cd9a..19fb6fe5cee 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -45,7 +45,7 @@ object Executor {
  * @see Symbol.bind : to create executor
  */
 class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
-                              private[mxnet] val symbol: Symbol) extends WarnIfNotDisposed {
+                              private[mxnet] val symbol: Symbol) extends NativeResource {
   private[mxnet] var argArrays: Array[NDArray] = null
   private[mxnet] var gradArrays: Array[NDArray] = null
   private[mxnet] var auxArrays: Array[NDArray] = null
@@ -59,14 +59,15 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
   private[mxnet] var _group2ctx: Map[String, Context] = null
   private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])
 
-  private var disposed = false
-  protected def isDisposed = disposed
-
-  def dispose(): Unit = {
-    if (!disposed) {
-      outputs.foreach(_.dispose())
-      _LIB.mxExecutorFree(handle)
-      disposed = true
+  override def nativeAddress: CPtrAddress = handle
+  override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
+  // cannot determine the off-heap size of this object
+  override val bytesAllocated: Long = 0
+  override val ref: NativeResourceRef = super.register()
+  override def dispose(): Unit = {
+    if (!super.isDisposed) {
+      super.dispose()
+      outputs.foreach(o => o.dispose())
     }
   }
 
@@ -305,4 +306,5 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
     checkCall(_LIB.mxExecutorPrint(handle, str))
     str.value
   }
+
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
index 8e89ce76b87..45189a13aef 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
@@ -52,22 +52,17 @@ object KVStore {
   }
 }
 
-class KVStore(private[mxnet] val handle: KVStoreHandle) extends WarnIfNotDisposed {
+class KVStore(private[mxnet] val handle: KVStoreHandle) extends NativeResource {
   private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore])
   private var updaterFunc: MXKVStoreUpdater = null
-  private var disposed = false
-  protected def isDisposed = disposed
 
-  /**
-   * Release the native memory.
-   * The object shall never be used after it is disposed.
-   */
-  def dispose(): Unit = {
-    if (!disposed) {
-      _LIB.mxKVStoreFree(handle)
-      disposed = true
-    }
-  }
+  override def nativeAddress: CPtrAddress = handle
+
+  override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxKVStoreFree
+
+  override val ref: NativeResourceRef = super.register()
+
+  override val bytesAllocated: Long = 0L
 
   /**
    * Initialize a single or a sequence of key-value pairs into the store.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
index 4bb9cdd331a..b835c4964dd 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
@@ -259,7 +259,9 @@ object Model {
                                       workLoadList: Seq[Float] = Nil,
                                       monitor: Option[Monitor] = None,
                                       symGen: SymbolGenerator = null): Unit = {
-    val executorManager = new DataParallelExecutorManager(
+    ResourceScope.using() {
+
+      val executorManager = new DataParallelExecutorManager(
         symbol = symbol,
         symGen = symGen,
         ctx = ctx,
@@ -269,17 +271,17 @@ object Model {
         auxNames = auxNames,
         workLoadList = workLoadList)
 
-    monitor.foreach(executorManager.installMonitor)
-    executorManager.setParams(argParams, auxParams)
+      monitor.foreach(executorManager.installMonitor)
+      executorManager.setParams(argParams, auxParams)
 
-    // updater for updateOnKVStore = false
-    val updaterLocal = Optimizer.getUpdater(optimizer)
+      // updater for updateOnKVStore = false
+      val updaterLocal = Optimizer.getUpdater(optimizer)
 
-    kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
-      argParams, executorManager.paramNames, updateOnKVStore))
-    if (updateOnKVStore) {
-      kvStore.foreach(_.setOptimizer(optimizer))
-    }
+      kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
+        argParams, executorManager.paramNames, updateOnKVStore))
+      if (updateOnKVStore) {
+        kvStore.foreach(_.setOptimizer(optimizer))
+      }
 
     // Now start training
     for (epoch <- beginEpoch until endEpoch) {
@@ -290,45 +292,46 @@ object Model {
       var epochDone = false
       // Iterate over training data.
       trainData.reset()
-      while (!epochDone) {
-        var doReset = true
-        while (doReset && trainData.hasNext) {
-          val dataBatch = trainData.next()
-          executorManager.loadDataBatch(dataBatch)
-          monitor.foreach(_.tic())
-          executorManager.forward(isTrain = true)
-          executorManager.backward()
-          if (updateOnKVStore) {
-            updateParamsOnKVStore(executorManager.paramArrays,
-              executorManager.gradArrays,
-              kvStore, executorManager.paramNames)
-          } else {
-            updateParams(executorManager.paramArrays,
-              executorManager.gradArrays,
-              updaterLocal, ctx.length,
-              executorManager.paramNames,
-              kvStore)
-          }
-          monitor.foreach(_.tocPrint())
-          // evaluate at end, so out_cpu_array can lazy copy
-          executorManager.updateMetric(evalMetric, dataBatch.label)
+      ResourceScope.using() {
+        while (!epochDone) {
+          var doReset = true
+          while (doReset && trainData.hasNext) {
+            val dataBatch = trainData.next()
+            executorManager.loadDataBatch(dataBatch)
+            monitor.foreach(_.tic())
+            executorManager.forward(isTrain = true)
+            executorManager.backward()
+            if (updateOnKVStore) {
+              updateParamsOnKVStore(executorManager.paramArrays,
+                executorManager.gradArrays,
+                kvStore, executorManager.paramNames)
+            } else {
+              updateParams(executorManager.paramArrays,
+                executorManager.gradArrays,
+                updaterLocal, ctx.length,
+                executorManager.paramNames,
+                kvStore)
+            }
+            monitor.foreach(_.tocPrint())
+            // evaluate at end, so out_cpu_array can lazy copy
+            executorManager.updateMetric(evalMetric, dataBatch.label)
 
-          nBatch += 1
-          batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))
+            nBatch += 1
+            batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))
 
-          // this epoch is done possibly earlier
-          if (epochSize != -1 && nBatch >= epochSize) {
-            doReset = false
+            // this epoch is done possibly earlier
+            if (epochSize != -1 && nBatch >= epochSize) {
+              doReset = false
+            }
+          }
+          if (doReset) {
+            trainData.reset()
           }
-        }
-        if (doReset) {
-          trainData.reset()
-        }
 
-        // this epoch is done
-        epochDone = (epochSize == -1 || nBatch >= epochSize)
+          // this epoch is done
+          epochDone = (epochSize == -1 || nBatch >= epochSize)
+        }
       }
-
       val (name, value) = evalMetric.get
       name.zip(value).foreach { case (n, v) =>
         logger.info(s"Epoch[$epoch] Train-$n=$v")
@@ -336,20 +339,22 @@ object Model {
       val toc = System.currentTimeMillis
       logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
 
-      evalData.foreach { evalDataIter =>
-        evalMetric.reset()
-        evalDataIter.reset()
-        // TODO: make DataIter implement Iterator
-        while (evalDataIter.hasNext) {
-          val evalBatch = evalDataIter.next()
-          executorManager.loadDataBatch(evalBatch)
-          executorManager.forward(isTrain = false)
-          executorManager.updateMetric(evalMetric, evalBatch.label)
-        }
+      ResourceScope.using() {
+        evalData.foreach { evalDataIter =>
+          evalMetric.reset()
+          evalDataIter.reset()
+          // TODO: make DataIter implement Iterator
+          while (evalDataIter.hasNext) {
+            val evalBatch = evalDataIter.next()
+            executorManager.loadDataBatch(evalBatch)
+            executorManager.forward(isTrain = false)
+            executorManager.updateMetric(evalMetric, evalBatch.label)
+          }
 
-        val (name, value) = evalMetric.get
-        name.zip(value).foreach { case (n, v) =>
-          logger.info(s"Epoch[$epoch] Train-$n=$v")
+          val (name, value) = evalMetric.get
+          name.zip(value).foreach { case (n, v) =>
+            logger.info(s"Epoch[$epoch] Validation-$n=$v")
+          }
         }
       }
 
@@ -359,8 +364,7 @@ object Model {
       epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams))
     }
 
-    updaterLocal.dispose()
-    executorManager.dispose()
+    }
   }
   // scalastyle:on parameterNum
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 9b6a7dc6654..f2a7603caa8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -562,16 +562,20 @@ object NDArray extends NDArrayBase {
  */
 class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
                              val writable: Boolean = true,
-                             addToCollector: Boolean = true) extends WarnIfNotDisposed {
+                             addToCollector: Boolean = true) extends NativeResource {
   if (addToCollector) {
     NDArrayCollector.collect(this)
   }
 
+  override def nativeAddress: CPtrAddress = handle
+  override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree
+  override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product
+
+  override val ref: NativeResourceRef = super.register()
+
   // record arrays who construct this array instance
   // we use weak reference to prevent gc blocking
   private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
-  @volatile private var disposed = false
-  def isDisposed: Boolean = disposed
 
   def serialize(): Array[Byte] = {
     val buf = ArrayBuffer.empty[Byte]
@@ -584,11 +588,10 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
    * The NDArrays it depends on will NOT be disposed. <br />
    * The object shall never be used after it is disposed.
    */
-  def dispose(): Unit = {
-    if (!disposed) {
-      _LIB.mxNDArrayFree(handle)
+  override def dispose(): Unit = {
+    if (!super.isDisposed) {
+      super.dispose()
       dependencies.clear()
-      disposed = true
     }
   }
 
@@ -1034,6 +1037,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     // TODO: naive implementation
     shape.hashCode + toArray.hashCode
   }
+
 }
 
 private[mxnet] object NDArrayConversions {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
new file mode 100644
index 00000000000..48d4b0c193b
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base.CPtrAddress
+import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference}
+import java.util.concurrent._
+
+import org.apache.mxnet.Base.checkCall
+import java.util.concurrent.atomic.AtomicLong
+
+
+/**
+  * NativeResource trait is used to manage MXNet objects
+  * such as NDArray, Symbol, Executor, etc.
+  * The MXNet object calls NativeResource.register and assigns the
+  * returned NativeResourceRef (a PhantomReference) to its `ref` member.
+  * NativeResource also implements AutoCloseable so MXNet objects
+  * can be used like resources in the try-with-resources paradigm.
+  */
+private[mxnet] trait NativeResource
+  extends AutoCloseable with WarnIfNotDisposed {
+
+  /**
+    * native Address associated with this object
+    */
+  def nativeAddress: CPtrAddress
+
+  /**
+    * Function Pointer to the NativeDeAllocator of nativeAddress
+    */
+  def nativeDeAllocator: (CPtrAddress => Int)
+
+  /** Call NativeResource.register to get the reference
+    */
+  val ref: NativeResourceRef
+
+  /**
+    * Off-Heap Bytes Allocated for this object
+    */
+  // intentionally making it a val, so it gets evaluated when defined
+  val bytesAllocated: Long
+
+  private[mxnet] var scope: Option[ResourceScope] = None
+
+  @volatile private var disposed = false
+
+  override def isDisposed: Boolean = disposed || isDeAllocated
+
+  /**
+    * Register this object for PhantomReference tracking and in
+    * ResourceScope if used inside ResourceScope.
+    * @return NativeResourceRef that tracks reachability of this object
+    *         using PhantomReference
+    */
+  def register(): NativeResourceRef = {
+    scope = ResourceScope.getCurrentScope()
+    if (scope.isDefined) scope.get.add(this)
+
+    NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
+    // register with PhantomRef tracking to release in case the objects go
+    // out of reference within scope but are held for a long time
+    NativeResourceRef.register(this, nativeDeAllocator)
+ }
+
+  // Implements [[@link AutoCloseable.close]]
+  override def close(): Unit = {
+    dispose()
+  }
+
+  // Implements [[@link WarnIfNotDisposed.dispose]]
+  def dispose(): Unit = dispose(true)
+
+  /**
+    * This method deAllocates nativeResource and deRegisters
+    * from PhantomRef and removes from Scope if
+    * removeFromScope is set to true.
+    * @param removeFromScope remove from the currentScope if true
+    */
+  // the parameter here controls whether to remove from the current scope.
+  // [[ResourceScope.close]] calls NativeResource.dispose while iterating over its
+  // container; removing from that container during the iteration would make the
+  // subsequent iterator.next call undefined and unsafe, so it passes false.
+  // Note that ResourceScope automatically disposes all the resources within.
+  private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = {
+    if (!disposed) {
+      checkCall(nativeDeAllocator(this.nativeAddress))
+      NativeResourceRef.deRegister(ref) // removes from PhantomRef tracking
+      if (removeFromScope && scope.isDefined) scope.get.remove(this)
+      NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated)
+      disposed = true
+    }
+  }
+
+  /*
+  This is used by the WarnIfNotDisposed finalizer.
+  The native object may already have been deallocated through GC, without an explicit
+  dispose(), before the finalizer runs; without this check WarnIfNotDisposed would warn.
+   */
+  private[mxnet] def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref)
+
+}
+
+private[mxnet] object NativeResource {
+  var totalBytesAllocated : AtomicLong = new AtomicLong(0)
+}
+
+// Do not make the `resource` constructor parameter a member of the class:
+// that would hold a strong reference to it and GC will not clear the object.
+private[mxnet] class NativeResourceRef(resource: NativeResource,
+                                       val resourceDeAllocator: CPtrAddress => Int)
+        extends PhantomReference[NativeResource](resource, NativeResourceRef.refQ) {}
+
+private[mxnet] object NativeResourceRef {
+
+  private[mxnet] val refQ: ReferenceQueue[NativeResource]
+                = new ReferenceQueue[NativeResource]
+
+  private[mxnet] val refMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]()
+
+  private[mxnet] val cleaner = new ResourceCleanupThread()
+
+  cleaner.start()
+
+  def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)):
+  NativeResourceRef = {
+    val ref = new NativeResourceRef(resource, nativeDeAllocator)
+    refMap.put(ref, resource.nativeAddress)
+    ref
+  }
+
+  // remove from PhantomRef tracking
+  def deRegister(ref: NativeResourceRef): Unit = refMap.remove(ref)
+
+  /**
+    * This method will check if the cleaner ran and deAllocated the object
+    * As a part of GC, when the object is unreachable GC inserts a phantomRef
+    * to the ReferenceQueue which the cleaner thread will deallocate, however
+    * the finalizer runs much later depending on the GC.
+    * @param ref reference of the resource to verify if it has been deAllocated
+    * @return true if already deAllocated
+    */
+  def isDeAllocated(ref: NativeResourceRef): Boolean = {
+    !refMap.containsKey(ref)
+  }
+
+  def cleanup: Unit = {
+    // remove is a blocking call
+    val ref: NativeResourceRef = refQ.remove().asInstanceOf[NativeResourceRef]
+    // phantomRef will be removed from the map when NativeResource.close is called.
+    val resource = refMap.get(ref)
+    if (resource != 0L)  { // since CPtrAddress is a Scala Long, it cannot be null
+      ref.resourceDeAllocator(resource)
+      refMap.remove(ref)
+    }
+  }
+
+  protected class ResourceCleanupThread extends Thread {
+    setPriority(Thread.MAX_PRIORITY)
+    setName("NativeResourceDeAllocatorThread")
+    setDaemon(true)
+
+    override def run(): Unit = {
+      while (true) {
+        try {
+          NativeResourceRef.cleanup
+        }
+        catch {
+          case _: InterruptedException => Thread.currentThread().interrupt()
+        }
+      }
+    }
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
index 758cbc82961..c3f8aaec6d6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
@@ -19,6 +19,8 @@ package org.apache.mxnet
 
 import java.io._
 
+import org.apache.mxnet.Base.CPtrAddress
+
 import scala.collection.mutable
 import scala.util.Either
 
@@ -38,8 +40,10 @@ object Optimizer {
       }
 
       override def dispose(): Unit = {
-        states.values.foreach(optimizer.disposeState)
-        states.clear()
+        if (!super.isDisposed) {
+          states.values.foreach(optimizer.disposeState)
+          states.clear()
+        }
       }
 
       override def serializeState(): Array[Byte] = {
@@ -285,7 +289,8 @@ abstract class Optimizer extends Serializable {
   }
 }
 
-trait MXKVStoreUpdater {
+trait MXKVStoreUpdater extends
+  NativeResource {
   /**
    * user-defined updater for the kvstore
    * It's this updater's responsibility to delete recv and local
@@ -294,9 +299,14 @@ trait MXKVStoreUpdater {
    * @param local the value stored on local on this key
    */
   def update(key: Int, recv: NDArray, local: NDArray): Unit
-  def dispose(): Unit
-  // def serializeState(): Array[Byte]
-  // def deserializeState(bytes: Array[Byte]): Unit
+
+  // This is a hack to make Optimizers work with ResourceScope
+  // otherwise the user has to manage calling dispose on this object.
+  override def nativeAddress: CPtrAddress = hashCode()
+  override def nativeDeAllocator: CPtrAddress => Int = doNothingDeAllocator
+  private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0
+  override val ref: NativeResourceRef = super.register()
+  override val bytesAllocated: Long = 0L
 }
 
 trait MXKVStoreCachedStates {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
new file mode 100644
index 00000000000..1c5782d873a
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.util.HashSet
+
+import org.slf4j.LoggerFactory
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Try
+import scala.util.control.{ControlThrowable, NonFatal}
+
+/**
+  * This class manages automatically releasing of [[NativeResource]]s
+  */
+class ResourceScope extends AutoCloseable {
+
+  // HashSet does not take a custom comparator
+  private[mxnet] val resourceQ = new mutable.TreeSet[NativeResource]()(nativeAddressOrdering)
+
+  private object nativeAddressOrdering extends Ordering[NativeResource] {
+    def compare(a: NativeResource, b: NativeResource): Int = {
+      a.nativeAddress compare  b.nativeAddress
+    }
+  }
+
+  ResourceScope.addToThreadLocal(this)
+
+  /**
+    * Releases all the [[NativeResource]] by calling
+    * the associated [[NativeResource.close()]] method
+    */
+  override def close(): Unit = {
+    ResourceScope.removeFromThreadLocal(this)
+    resourceQ.foreach(resource => if (resource != null) resource.dispose(false) )
+    resourceQ.clear()
+  }
+
+  /**
+    * Add a NativeResource to the scope
+    * @param resource
+    */
+  def add(resource: NativeResource): Unit = {
+    resourceQ.+=(resource)
+  }
+
+  /**
+    * Remove NativeResource from the Scope; the resource is looked up
+    * in the underlying TreeSet by its nativeAddress ordering.
+    * @param resource the NativeResource to remove
+    */
+  def remove(resource: NativeResource): Unit = {
+    resourceQ.-=(resource)
+  }
+}
+
+object ResourceScope {
+
+  private val logger = LoggerFactory.getLogger(classOf[ResourceScope])
+
+  /**
+    * Captures all Native Resources created using the ResourceScope and
+    * at the end of the body, de allocates all the Native resources by calling close on them.
+    * This method will not deAllocate NativeResources returned from the block.
+    * @param scope (Optional). Scope in which to capture the native resources
+    * @param body  block of code to execute in this scope
+    * @tparam A return type
+    * @return result of the operation, if the result is of type NativeResource, it is not
+    *         de allocated so the user can use it and then de allocate manually by calling
+    *         close or enclose in another resourceScope.
+    */
+  // inspired from slide 21 of https://www.slideshare.net/Odersky/fosdem-2009-1013261
+  // and https://github.com/scala/scala/blob/2.13.x/src/library/scala/util/Using.scala
+  // TODO: we should move to the Scala util's Using method when we move to Scala 2.13
+  def using[A](scope: ResourceScope = null)(body: => A): A = {
+
+    val curScope = if (scope != null) scope else new ResourceScope()
+
+    val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
+
+    @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
+      g.foreach( n =>
+        n match {
+          case nRes: NativeResource => {
+            removeAndAddToPrevScope(nRes)
+          }
+          case kv: scala.Tuple2[_, _] => {
+            if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+              kv._1.asInstanceOf[NativeResource])
+            if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+              kv._2.asInstanceOf[NativeResource])
+          }
+        }
+      )
+    }
+
+    @inline def removeAndAddToPrevScope(r: NativeResource) = {
+      curScope.remove(r)
+      if (prevScope.isDefined)  {
+        prevScope.get.add(r)
+        r.scope = prevScope
+      }
+    }
+
+    @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
+      if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
+    }
+
+    var retThrowable: Throwable = null
+
+    try {
+      val ret = body
+       ret match {
+          // don't de-allocate if returning any collection that contains NativeResource.
+        case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric)
+        case nRes: NativeResource => removeAndAddToPrevScope(nRes)
+        case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) )
+        case _ => // do nothing
+      }
+      ret
+    } catch {
+      case t: Throwable =>
+        retThrowable = t
+        null.asInstanceOf[A] // we'll throw in finally
+    } finally {
+      var toThrow: Throwable = retThrowable
+      if (retThrowable eq null) curScope.close()
+      else {
+        try {
+          curScope.close
+        } catch {
+          case closeThrowable: Throwable =>
+            if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = closeThrowable
+            else safeAddSuppressed(retThrowable, closeThrowable)
+        } finally {
+          throw toThrow
+        }
+      }
+    }
+  }
+
+  // thread local Scopes
+  private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
+    override def initialValue(): ArrayBuffer[ResourceScope] =
+      new ArrayBuffer[ResourceScope]()
+  }
+
+  /**
+    * Add a ResourceScope to the current thread's ThreadLocal list of scopes
+    * @param r ResourceScope to add.
+    */
+  private[mxnet] def addToThreadLocal(r: ResourceScope): Unit = {
+    threadLocalScopes.get() += r
+  }
+
+  /**
+    * Remove a ResourceScope from the current thread's ThreadLocal list of scopes
+    * @param r ResourceScope to remove
+    */
+  private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
+    threadLocalScopes.get() -= r
+  }
+
+  /**
+    * Get the latest Scope in the stack
+    * @return the most recently added scope, or None when there is none
+    */
+  private[mxnet] def getCurrentScope(): Option[ResourceScope] = {
+    Try(Some(threadLocalScopes.get().last)).getOrElse(None)
+  }
+
+  /**
+    * Get the Last but one Scope from threadLocal Scopes.
+    * @return n-1th scope or None when not found
+    */
+  private[mxnet] def getPrevScope(): Option[ResourceScope] = {
+    val scopes = threadLocalScopes.get()
+    Try(Some(scopes(scopes.size - 2))).getOrElse(None)
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index b1a3e392f41..a009e7e343f 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -29,21 +29,15 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
  * WARNING: it is your responsibility to clear this object through dispose().
  * </b>
  */
-class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotDisposed {
+class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeResource {
   private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol])
-  private var disposed = false
-  protected def isDisposed = disposed
 
-  /**
-   * Release the native memory.
-   * The object shall never be used after it is disposed.
-   */
-  def dispose(): Unit = {
-    if (!disposed) {
-      _LIB.mxSymbolFree(handle)
-      disposed = true
-    }
-  }
+  // unable to determine bytesAllocated for Symbol
+  override val bytesAllocated: Long = 0L
+  override def nativeAddress: CPtrAddress = handle
+  override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree
+  override val ref: NativeResourceRef = super.register()
+
 
   def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other))
   def +[@specialized(Int, Float, Double) V](other: V): Symbol = {
@@ -793,7 +787,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
     }
 
     val execHandle = new ExecutorHandleRef
-    val sharedHadle = if (sharedExec != null) sharedExec.handle else 0L
+    val sharedHandle = if (sharedExec != null) sharedExec.handle else 0L
     checkCall(_LIB.mxExecutorBindEX(handle,
                                    ctx.deviceTypeid,
                                    ctx.deviceId,
@@ -806,7 +800,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
                                    argsGradHandle,
                                    reqsArray,
                                    auxArgsHandle,
-                                   sharedHadle,
+                                   sharedHandle,
                                    execHandle))
     val executor = new Executor(execHandle.value, this.clone())
     executor.argArrays = argsNDArray
@@ -832,6 +826,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
     checkCall(_LIB.mxSymbolSaveToJSON(handle, jsonStr))
     jsonStr.value
   }
+
 }
 
 /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index f7f858deb82..998017750db 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -33,7 +33,7 @@ import scala.collection.mutable.ListBuffer
 private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
                                 dataName: String = "data",
                                 labelName: String = "label")
-  extends DataIter with WarnIfNotDisposed {
+  extends DataIter with NativeResource {
 
   private val logger = LoggerFactory.getLogger(classOf[MXDataIter])
 
@@ -67,20 +67,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
     }
   }
 
+  override def nativeAddress: CPtrAddress = handle
 
-  private var disposed = false
-  protected def isDisposed = disposed
+  override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxDataIterFree
 
-  /**
-   * Release the native memory.
-   * The object shall never be used after it is disposed.
-   */
-  def dispose(): Unit = {
-    if (!disposed) {
-      _LIB.mxDataIterFree(handle)
-      disposed = true
-    }
-  }
+  override val ref: NativeResourceRef = super.register()
+
+  override val bytesAllocated: Long = 0L
 
   /**
    * reset the iterator
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
index e20b433ed1e..d349feac3e9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.mxnet.optimizer
 
-import org.apache.mxnet.{Optimizer, LRScheduler, NDArray}
+import org.apache.mxnet._
 import org.apache.mxnet.NDArrayConversions._
 
 /**
@@ -92,7 +92,13 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
     if (momentum == 0.0f) {
       null
     } else {
-      NDArray.zeros(weight.shape, weight.context)
+      val s = NDArray.zeros(weight.shape, weight.context)
+      // this array is created on the fly and shared between runs,
+      // so we don't want it to be disposed by the enclosing scope;
+      // it should instead be released by an explicit dispose
+      val scope = ResourceScope.getCurrentScope()
+      if (scope.isDefined) scope.get.remove(s)
+      s
     }
   }
 
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala
new file mode 100644
index 00000000000..81a9f605a88
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.lang.ref.ReferenceQueue
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.mxnet.Base.CPtrAddress
+import org.mockito.Matchers.any
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, TagAnnotation}
+import org.mockito.Mockito._
+
+@TagAnnotation("resource")
+class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+  object TestRef  {
+    def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ}
+    def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress]
+    = {NativeResourceRef.refMap}
+    def getCleaner: Thread = { NativeResourceRef.cleaner }
+  }
+
+  class TestRef(resource: NativeResource,
+                          resourceDeAllocator: CPtrAddress => Int)
+    extends NativeResourceRef(resource, resourceDeAllocator) {
+  }
+
+  test(testName = "test native resource setup/teardown") {
+    val a = spy(NDArray.ones(Shape(2, 3)))
+    val aRef = a.ref
+    val spyRef = spy(aRef)
+
+    assert(TestRef.getRefMap.containsKey(aRef) == true)
+    a.close()
+    verify(a).dispose()
+    verify(a).nativeDeAllocator
+    // resourceDeAllocator does not get called when explicitly closing
+    verify(spyRef, times(0)).resourceDeAllocator
+
+    assert(TestRef.getRefMap.containsKey(aRef) == false)
+    assert(a.isDisposed == true, "isDisposed should be set to true after calling close")
+  }
+
+  test(testName = "test dispose") {
+    val a: NDArray = spy(NDArray.ones(Shape(3, 4)))
+    val aRef = a.ref
+    val spyRef = spy(aRef)
+    a.dispose()
+    verify(a).nativeDeAllocator
+    assert(TestRef.getRefMap.containsKey(aRef) == false)
+    assert(a.isDisposed == true, "isDisposed should be set to true after calling close")
+  }
+}
+
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
new file mode 100644
index 00000000000..41dfa7d0ead
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.lang.ref.ReferenceQueue
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.mxnet.Base.CPtrAddress
+import org.apache.mxnet.ResourceScope.logger
+import org.mockito.Matchers.any
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.mockito.Mockito._
+import scala.collection.mutable.HashMap
+
+class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+  class TestNativeResource extends NativeResource {
+    /**
+      * native Address associated with this object
+      */
+    override def nativeAddress: CPtrAddress = hashCode()
+
+    /**
+      * Function Pointer to the NativeDeAllocator of nativeAddress
+      */
+    override def nativeDeAllocator: CPtrAddress => Int = TestNativeResource.deAllocator
+
+    /** Call NativeResource.register to get the reference
+      */
+    override val ref: NativeResourceRef = super.register()
+    /**
+      * Off-Heap Bytes Allocated for this object
+      */
+    override val bytesAllocated: Long = 0
+  }
+  object TestNativeResource {
+    def deAllocator(handle: CPtrAddress): Int = 0
+  }
+
+  object TestPhantomRef  {
+    def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ}
+    def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress]
+    = {NativeResourceRef.refMap}
+    def getCleaner: Thread = { NativeResourceRef.cleaner }
+
+  }
+
+  class TestPhantomRef(resource: NativeResource,
+                       resourceDeAllocator: CPtrAddress => Int)
+    extends NativeResourceRef(resource, resourceDeAllocator) {
+  }
+
+  test(testName = "test NDArray Auto Release") {
+    var a: NDArray = null
+    var aRef: NativeResourceRef = null
+    var b: NDArray = null
+
+    ResourceScope.using() {
+      b = ResourceScope.using() {
+          a = NDArray.ones(Shape(3, 4))
+          aRef = a.ref
+          val x = NDArray.ones(Shape(3, 4))
+        x
+      }
+      val bRef: NativeResourceRef = b.ref
+      assert(a.isDisposed == true,
+        "objects created within scope should have isDisposed set to true")
+      assert(b.isDisposed == false,
+        "returned NativeResource should not be released")
+      assert(TestPhantomRef.getRefMap.containsKey(aRef) == false,
+        "reference of resource in Scope should be removed refMap")
+      assert(TestPhantomRef.getRefMap.containsKey(bRef) == true,
+        "reference of resource outside scope should be not removed refMap")
+    }
+    assert(b.isDisposed, "resource returned from inner scope should be released in outer scope")
+  }
+
+  test("test return object release from outer scope") {
+    var a: TestNativeResource = null
+    ResourceScope.using() {
+      a = ResourceScope.using() {
+        new TestNativeResource()
+      }
+      assert(a.isDisposed == false, "returned object should not be disposed within Using")
+    }
+    assert(a.isDisposed == true, "returned object should be disposed in the outer scope")
+  }
+
+  test(testName = "test NativeResources in returned Lists are not disposed") {
+    var ndListRet: IndexedSeq[TestNativeResource] = null
+    ResourceScope.using() {
+      ndListRet = ResourceScope.using() {
+        val ndList: IndexedSeq[TestNativeResource] =
+          IndexedSeq(new TestNativeResource(), new TestNativeResource())
+        ndList
+      }
+      ndListRet.foreach(nd => assert(nd.isDisposed == false,
+        "NativeResources within a returned collection should not be disposed"))
+    }
+    ndListRet.foreach(nd => assert(nd.isDisposed == true,
+    "NativeResources returned from inner scope should be disposed in outer scope"))
+  }
+
+  test("test native resource inside a map") {
+    var nRInKeyOfMap: HashMap[TestNativeResource, String] = null
+    var nRInValOfMap: HashMap[String, TestNativeResource] = HashMap[String, TestNativeResource]()
+
+    ResourceScope.using() {
+      nRInKeyOfMap = ResourceScope.using() {
+        val ret = HashMap[TestNativeResource, String]()
+        ret.put(new TestNativeResource, "hello")
+        ret
+      }
+      assert(!nRInKeyOfMap.isEmpty)
+
+      nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed == false,
+      "NativeResources returned in Traversable should not be disposed"))
+    }
+
+    nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed))
+
+    ResourceScope.using() {
+
+      nRInValOfMap = ResourceScope.using() {
+        val ret = HashMap[String, TestNativeResource]()
+        ret.put("world!", new TestNativeResource)
+        ret
+      }
+      assert(!nRInValOfMap.isEmpty)
+      nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed == false,
+        "NativeResources returned in Collection should not be disposed"))
+    }
+    nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed))
+  }
+
+}
diff --git a/scala-package/examples/scripts/run_train_mnist.sh b/scala-package/examples/scripts/run_train_mnist.sh
new file mode 100755
index 00000000000..ea53c1ade66
--- /dev/null
+++ b/scala-package/examples/scripts/run_train_mnist.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd)
+echo $MXNET_ROOT
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+
+# data dir (NOTE: DATA_PATH is not used below; --data-dir is hard-coded)
+DATA_PATH=$2
+
+java -XX:+PrintGC -Xms256M -Xmx512M -Dmxnet.traceLeakedObjects=false -cp $CLASS_PATH \
+        org.apache.mxnetexamples.imclassification.TrainMnist \
+        --data-dir /home/ubuntu/mxnet_scala/scala-package/examples/mnist/ \
+        --num-epochs 10000000 \
+        --batch-size 1024
\ No newline at end of file


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services