You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2019/03/28 18:57:37 UTC

[incubator-mxnet] branch master updated: Memory fixes. Resolves #10867, and resolves #14080 (#14372)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 102b46f  Memory fixes. Resolves #10867, and resolves #14080 (#14372)
102b46f is described below

commit 102b46feb5bf1061545ac79e4b114a180250740e
Author: Andrew Ayres <an...@gmail.com>
AuthorDate: Thu Mar 28 11:57:09 2019 -0700

    Memory fixes. Resolves #10867, and resolves #14080 (#14372)
    
    * Fixes for memory leak when reshaping executor
    
    * Fixed Adam Optimizer memory leak
    
    * Cleanup for PR
    
    * Added unit test for new ResourceScope method
    
    * Removing import that was added by an overzealous IDE
    
    * Add back in an import
    
    * Added flags for executor to know whether or not it owns NDArrays for disposal
    
    * Moving to ResourceScope.using implementation
    
    * Changes to make ResourceScope.using work with existing scope
    
    * Updating ResourceScope to work with existing scopes via usingIfScopeExists method
    
    * Fix clojure unit tests
    
    * Fixes to be compatible with how Clojure is using ResourceScope
    
    * Removing some unnecessary changes
    
    * Adding scope assertion in unit test
---
 .../src/main/scala/org/apache/mxnet/Executor.scala |  47 ++++++++--
 .../main/scala/org/apache/mxnet/Optimizer.scala    |  20 ++--
 .../scala/org/apache/mxnet/ResourceScope.scala     |  21 ++++-
 .../src/main/scala/org/apache/mxnet/Symbol.scala   |  21 +++--
 .../mxnet/module/DataParallelExecutorGroup.scala   |  11 ++-
 .../scala/org/apache/mxnet/optimizer/Adam.scala    | 101 ++++++++++-----------
 .../org/apache/mxnet/ResourceScopeSuite.scala      |  33 +++++++
 7 files changed, 165 insertions(+), 89 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index 85f45bc..aec4402 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -45,29 +45,47 @@ object Executor {
  * @see Symbol.bind : to create executor
  */
 class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
-                              private[mxnet] val symbol: Symbol) extends NativeResource {
-  private[mxnet] var argArrays: Array[NDArray] = null
-  private[mxnet] var gradArrays: Array[NDArray] = null
-  private[mxnet] var auxArrays: Array[NDArray] = null
+                              private[mxnet] val symbol: Symbol,
+                              private[mxnet] var argArrays: Array[NDArray] = null,
+                              private[mxnet] var gradArrays: Array[NDArray] = null,
+                              private[mxnet] var auxArrays: Array[NDArray] = null,
+                              private var _ctx: Context = null,
+                              private var _gradsReq: Iterable[_] = null,
+                              private var _group2ctx: Map[String, Context] = null
+                             ) extends NativeResource {
+
   val outputs: Array[NDArray] = getOutputs
   protected var _argDict: Map[String, NDArray] = null
   protected var _gradDict: Map[String, NDArray] = null
   protected var _auxDict: Map[String, NDArray] = null
   protected var monitorCallback: MXMonitorCallback = null
-  private[mxnet] var _ctx: Context = null
-  private[mxnet] var _gradsReq: Iterable[_] = null
-  private[mxnet] var _group2ctx: Map[String, Context] = null
   private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])
 
+  private[mxnet] var ownsArgArrays = false
+  private[mxnet] var ownsGradArrays = false
+  private[mxnet] var ownsAuxArrays = false
+
   override def nativeAddress: CPtrAddress = handle
   override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
   // cannot determine the off-heap size of this object
   override val bytesAllocated: Long = 0
   override val ref: NativeResourceRef = super.register()
+
   override def dispose(): Unit = {
     if (!super.isDisposed) {
       super.dispose()
       outputs.foreach(o => o.dispose())
+      // Symbol.bind clones symbol when creating the executor so we need to dispose of the clone
+      symbol.dispose()
+      if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())}
+      if (ownsGradArrays && gradArrays != null) {gradArrays.foreach(
+        // Symbol will sometimes fill this with nulls so we've got to check the elements too
+        a => if (a != null) {a.dispose()})
+      }
+      if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())}
+      if (_argDict != null) {_argDict.foreach(a => a._2.dispose())}
+      if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())}
+      if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())}
     }
   }
 
@@ -86,6 +104,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
    */
   def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
     kwargs: Map[String, Shape]): Executor = {
+    var setArgOwner = false
+    var setAuxOwner = false
+    var setGradOwner = false
      val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
     // TODO: more precise error message should be provided by backend
     require(argShapes != null, "Shape inference failed." +
@@ -107,8 +128,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
                         "If you really want to up size, set allowUpSizing = true " +
                         "to enable allocation of new arrays.")
           newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype))
+          setArgOwner = true
           if (dArr != null) {
             newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype))
+            setGradOwner = true
           }
         } else {
           newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
@@ -135,6 +158,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
                         "If you really want to up size, set allowUpSizing = true " +
                         "to enable allocation of new arrays.")
           newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context))
+          setAuxOwner = true
         } else {
           newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
         }
@@ -145,7 +169,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
                   "If this is intended, set partialShaping = true to suppress this warning.")
       }
     }
-    if (this._gradsReq.isInstanceOf[Seq[_]]) {
+    val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) {
       this.symbol.bind(this._ctx,
                           newArgDict,
                           newGradDict,
@@ -162,6 +186,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
                           this._group2ctx,
                           this)
     }
+
+    // This method has created new NDArrays that will need to be managed by the new Executor
+    if (setArgOwner) reshapedExecutor.ownsArgArrays = true
+    if (setGradOwner) reshapedExecutor.ownsGradArrays = true
+    if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true
+
+    reshapedExecutor
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
index da58976..123eae9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
@@ -28,15 +28,17 @@ object Optimizer {
   def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
     new MXKVStoreUpdater with MXKVStoreCachedStates {
       override def update(index: Int, grad: NDArray, weight: NDArray): Unit = {
-        val state =
-          if (states.contains(index)) {
-            states.get(index).get
-          } else {
-            val newState = optimizer.createState(index, weight)
-            states.put(index, newState)
-            newState
-          }
-        optimizer.update(index, weight, grad, state)
+        ResourceScope.usingIfScopeExists(this.scope) {
+          val state =
+            if (states.contains(index)) {
+              states.get(index).get
+            } else {
+              val newState = optimizer.createState(index, weight)
+              states.put(index, newState)
+              newState
+            }
+          optimizer.update(index, weight, grad, state)
+        }
       }
 
       override def dispose(): Unit = {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
index bb363c0..b955c18 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -48,8 +48,10 @@ class ResourceScope extends AutoCloseable {
     */
   override def close(): Unit = {
     ResourceScope.removeFromThreadLocal(this)
-    resourceQ.foreach(resource => if (resource != null) resource.dispose(false) )
-    resourceQ.clear()
+    if (!ResourceScope.threadLocalScopes.get().contains(this)) {
+      resourceQ.foreach(resource => if (resource != null) resource.dispose(false))
+      resourceQ.clear()
+    }
   }
 
   /**
@@ -145,7 +147,7 @@ object ResourceScope {
         null.asInstanceOf[A] // we'll throw in finally
     } finally {
       var toThrow: Throwable = retThrowable
-      if (retThrowable eq null) curScope.close()
+      if (retThrowable eq null) curScope.close
       else {
         try {
           curScope.close
@@ -160,6 +162,17 @@ object ResourceScope {
     }
   }
 
+  private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = {
+    if (scope == None) {
+      body
+    } else {
+      ResourceScope.addToThreadLocal(scope.get)
+      ResourceScope.using(scope.get){
+        body
+      }
+    }
+  }
+
   // thread local Scopes
   private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
     override def initialValue(): ArrayBuffer[ResourceScope] =
@@ -179,7 +192,7 @@ object ResourceScope {
     * @param r ResourceScope to remove
     */
   private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
-    threadLocalScopes.get() -= r
+    threadLocalScopes.get().remove(threadLocalScopes.get().lastIndexOf(r))
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 29885fc..821e04f 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -803,18 +803,23 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
                                    auxArgsHandle,
                                    sharedHandle,
                                    execHandle))
-    val executor = new Executor(execHandle.value, this.clone())
-    executor.argArrays = argsNDArray
-    executor.gradArrays = argsGradNDArray
-    executor.auxArrays = auxStatesNDArray
-    executor._ctx = new Context(ctx.deviceType, ctx.deviceId)
-    executor._gradsReq = gradsReq
-    executor._group2ctx =
+
+    val executorGroup2ctx =
       if (group2ctx == null) null
       else group2ctx.map { case (key, value) =>
         key -> new Context(value.deviceType, value.deviceId)
       }
-    executor
+
+    // If this is in a scope then we want to create the clone in the same scope
+    var newSymbol: Symbol = null
+    ResourceScope.usingIfScopeExists(this.scope) {
+      newSymbol = this.clone()
+    }
+
+    new Executor(execHandle.value, newSymbol, argsNDArray, argsGradNDArray,
+                                auxStatesNDArray, new Context(ctx.deviceType, ctx.deviceId),
+                                gradsReq, executorGroup2ctx)
+
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
index df66ea7..74e63be 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
@@ -299,7 +299,6 @@ class DataParallelExecutorGroup private[module](
 
   private var batchSize: Int = -1
   private var slices: Array[(Int, Int)] = null
-  private var _defaultExecs: Array[Executor] = null
   private var execs: Array[Executor] = null
   private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
   private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
@@ -373,7 +372,12 @@ class DataParallelExecutorGroup private[module](
         val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
         val inputShapes
           = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
-        execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
+
+        ResourceScope.usingIfScopeExists(execs(i).scope) {
+          val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
+          execs(i).dispose()
+          execs(i) = tmpExec
+        }
       }
     } else {
       execs = (0 until contexts.length).map(i =>
@@ -434,9 +438,6 @@ class DataParallelExecutorGroup private[module](
    */
   def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
     if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
-      if (this._defaultExecs == null) {
-        this._defaultExecs = this.execs.map(x => x)
-      }
       this.bindExec(dataShapes, labelShapes, None, reshape = true)
     }
   }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala
index 24f3323..5a8b3cb 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.optimizer
 
 import org.apache.mxnet.NDArrayConversions._
 import org.apache.mxnet.util.SerializerUtils
-import org.apache.mxnet.{LRScheduler, NDArray, Optimizer}
+import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, ResourceScope}
 
 /**
  * Adam optimizer as described in [King2014]
@@ -57,63 +57,54 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float =
    *              The auxiliary state used in optimization.
    */
   override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
-    var lr =
-      (if (lrScheduler != null) {
-        val scheduledLr = lrScheduler(numUpdate)
-        updateCount(index)
-        scheduledLr
-      } else {
-        this.learningRate
-      })
-    lr = getLr(index, lr)
-
-    val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]
-
-    // increment time only when the first parameters is called
-    timeFirstIndex match {
-      case Some(idx) =>
-        if (idx == index) time += 1
-      case None =>
-        timeFirstIndex = Option(index)
-        time = 0 // all parameters share the same time
-    }
-
-    val t1: Int = time + 1
-    val learningRate = (lr *
-      math.sqrt(1.0 - math.pow(beta2, t1)) /
-      (1.0 - math.pow(beta1, t1))).toFloat
-    val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat
-
-    var resdGrad = grad * rescaleGrad
-    if (clipGradient != 0f) {
-      val oldResdGrad = resdGrad
-      resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
-      oldResdGrad.dispose()
-    }
-
-    val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
-      .disposeDepsExcept(mean, resdGrad)
-    val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
-      .disposeDepsExcept(variance, resdGrad)
+    ResourceScope.using() {
+      var lr =
+        (if (lrScheduler != null) {
+          val scheduledLr = lrScheduler(numUpdate)
+          updateCount(index)
+          scheduledLr
+        } else {
+          this.learningRate
+        })
+      lr = getLr(index, lr)
 
-    val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))
-      .disposeDepsExcept(meanT, varianceT)
+      val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]
 
-    val wd = this.getWd(index, this.wd)
-    if (wd > 0.0f) {
-      val stepDelta = lr * wd * weight
-      step += stepDelta
-      stepDelta.dispose()
+      // increment time only when the first parameters is called
+      timeFirstIndex match {
+        case Some(idx) =>
+          if (idx == index) time += 1
+        case None =>
+          timeFirstIndex = Option(index)
+          time = 0 // all parameters share the same time
+      }
+
+      val t1: Int = time + 1
+      val learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) /
+        (1.0 - math.pow(beta1, t1))).toFloat
+      val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat
+
+      var resdGrad = grad * rescaleGrad
+      if (clipGradient != 0f) {
+        val oldResdGrad = resdGrad
+        resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
+      }
+
+      val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
+      val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
+      val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))
+
+      val wd = this.getWd(index, this.wd)
+      if (wd > 0.0f) {
+        val stepDelta = lr * wd * weight
+        step += stepDelta
+      }
+
+      weight -= step
+      mean.set(meanT)
+      variance.set(varianceT)
+      (mean, variance)
     }
-
-    weight -= step
-    mean.set(meanT)
-    variance.set(varianceT)
-
-    meanT.dispose()
-    varianceT.dispose()
-    step.dispose()
-    resdGrad.dispose()
   }
 
   // Create additional optimizer state: mean, variance
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
index 41dfa7d..1916238 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
@@ -101,6 +101,39 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers {
     assert(a.isDisposed == true, "returned object should be disposed in the outer scope")
   }
 
+  /**
+    * Tests passing a scope to using and creating new resources within.
+    */
+  test("test moving scope of native resource to scope of another") {
+    var a: TestNativeResource = null
+    var b: TestNativeResource = null
+    var c: TestNativeResource = null
+    var d: TestNativeResource = null
+
+    ResourceScope.using() {
+      a = new TestNativeResource()
+      ResourceScope.using() {
+        b = new TestNativeResource()
+        ResourceScope.usingIfScopeExists(a.scope) {
+          c = new TestNativeResource()
+          ResourceScope.using() {
+            d = new TestNativeResource()
+            assert(c.scope == a.scope)
+          }
+          assert(d.isDisposed == true)
+        }
+        assert(b.isDisposed == false)
+        assert(c.isDisposed == false)
+      }
+      assert(a.isDisposed == false)
+      assert(b.isDisposed == true)
+      assert(c.isDisposed == false)
+    }
+    assert(a.isDisposed == true)
+    assert(b.isDisposed == true)
+    assert(c.isDisposed == true)
+  }
+
   test(testName = "test NativeResources in returned Lists are not disposed") {
     var ndListRet: IndexedSeq[TestNativeResource] = null
     ResourceScope.using() {