Posted to commits@mxnet.apache.org by ns...@apache.org on 2019/03/28 18:57:37 UTC
[incubator-mxnet] branch master updated: Memory fixes. Resolves #10867, and resolves #14080 (#14372)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 102b46f Memory fixes. Resolves #10867, and resolves #14080 (#14372)
102b46f is described below
commit 102b46feb5bf1061545ac79e4b114a180250740e
Author: Andrew Ayres <an...@gmail.com>
AuthorDate: Thu Mar 28 11:57:09 2019 -0700
Memory fixes. Resolves #10867, and resolves #14080 (#14372)
* Fixes for memory leak when reshaping executor
* Fixed Adam Optimizer memory leak
* Cleanup for PR
* Added unit test for new ResourceScope method
* Removing an import that was added by an overzealous IDE
* Add back in an import
* Added flags for executor to know whether or not it owns NDArrays for disposal
* Moving to ResourceScope.using implementation
* Changes to make ResourceScope.using work with existing scope
* Updating ResourceScope to work with existing scopes via usingIfScopeExists method
* Fix Clojure unit tests
* Fixes to be compatible with how Clojure is using ResourceScope
* Removing some unnecessary changes
* Adding scope assertion in unit test
---
.../src/main/scala/org/apache/mxnet/Executor.scala | 47 ++++++++--
.../main/scala/org/apache/mxnet/Optimizer.scala | 20 ++--
.../scala/org/apache/mxnet/ResourceScope.scala | 21 ++++-
.../src/main/scala/org/apache/mxnet/Symbol.scala | 21 +++--
.../mxnet/module/DataParallelExecutorGroup.scala | 11 ++-
.../scala/org/apache/mxnet/optimizer/Adam.scala | 101 ++++++++++-----------
.../org/apache/mxnet/ResourceScopeSuite.scala | 33 +++++++
7 files changed, 165 insertions(+), 89 deletions(-)
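
Taken together, the changes below rest on two mechanisms: Executor gains ownsArgArrays/ownsGradArrays/ownsAuxArrays flags so that dispose() frees only NDArrays the executor itself allocated, and ResourceScope gains a usingIfScopeExists helper so code that already holds a scope (an executor's, an updater's) can attach new resources to it rather than to whatever scope is innermost at call time. A minimal sketch of the helper's contract, written as package-internal code since the helper is private[mxnet] (maybeScope is an illustrative Option[ResourceScope]):

    // Illustrative: with None the body simply runs; with Some(scope),
    // resources the body creates are registered in that existing scope.
    val arr = ResourceScope.usingIfScopeExists(maybeScope) {
      NDArray.zeros(Shape(2, 2))
    }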
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index 85f45bc..aec4402 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -45,29 +45,47 @@ object Executor {
* @see Symbol.bind : to create executor
*/
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
- private[mxnet] val symbol: Symbol) extends NativeResource {
- private[mxnet] var argArrays: Array[NDArray] = null
- private[mxnet] var gradArrays: Array[NDArray] = null
- private[mxnet] var auxArrays: Array[NDArray] = null
+ private[mxnet] val symbol: Symbol,
+ private[mxnet] var argArrays: Array[NDArray] = null,
+ private[mxnet] var gradArrays: Array[NDArray] = null,
+ private[mxnet] var auxArrays: Array[NDArray] = null,
+ private var _ctx: Context = null,
+ private var _gradsReq: Iterable[_] = null,
+ private var _group2ctx: Map[String, Context] = null
+ ) extends NativeResource {
+
val outputs: Array[NDArray] = getOutputs
protected var _argDict: Map[String, NDArray] = null
protected var _gradDict: Map[String, NDArray] = null
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null
- private[mxnet] var _ctx: Context = null
- private[mxnet] var _gradsReq: Iterable[_] = null
- private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])
+ private[mxnet] var ownsArgArrays = false
+ private[mxnet] var ownsGradArrays = false
+ private[mxnet] var ownsAuxArrays = false
+
override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
override val bytesAllocated: Long = 0
override val ref: NativeResourceRef = super.register()
+
override def dispose(): Unit = {
if (!super.isDisposed) {
super.dispose()
outputs.foreach(o => o.dispose())
+ // Symbol.bind clones symbol when creating the executor so we need to dispose of the clone
+ symbol.dispose()
+ if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())}
+ if (ownsGradArrays && gradArrays != null) {gradArrays.foreach(
+ // Symbol will sometimes fill this with nulls so we've got to check the elements too
+ a => if (a != null) {a.dispose()})
+ }
+ if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())}
+ if (_argDict != null) {_argDict.foreach(a => a._2.dispose())}
+ if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())}
+ if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())}
}
}
@@ -86,6 +104,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
*/
def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
kwargs: Map[String, Shape]): Executor = {
+ var setArgOwner = false
+ var setAuxOwner = false
+ var setGradOwner = false
val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
// TODO: more precise error message should be provided by backend
require(argShapes != null, "Shape inference failed." +
@@ -107,8 +128,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype))
+ setArgOwner = true
if (dArr != null) {
newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype))
+ setGradOwner = true
}
} else {
newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
@@ -135,6 +158,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context))
+ setAuxOwner = true
} else {
newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
}
@@ -145,7 +169,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If this is intended, set partialShaping = true to suppress this warning.")
}
}
- if (this._gradsReq.isInstanceOf[Seq[_]]) {
+ val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) {
this.symbol.bind(this._ctx,
newArgDict,
newGradDict,
@@ -162,6 +186,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
this._group2ctx,
this)
}
+
+ // This method has created new NDArrays that will need to be managed by the new Executor
+ if (setArgOwner) reshapedExecutor.ownsArgArrays = true
+ if (setGradOwner) reshapedExecutor.ownsGradArrays = true
+ if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true
+
+ reshapedExecutor
}
/**
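
The ownership flags matter only on the up-sizing path of reshape: when a requested shape is larger than the existing storage, reshape allocates fresh NDArrays and the executor it returns must free them itself; when the existing storage can simply be reshaped, the caller keeps ownership. A rough usage sketch (exec, the argument name, and the shape are illustrative):

    // Illustrative: up-sizing forces new allocations, so the reshaped
    // executor owns them and frees them in its own dispose().
    val bigger = exec.reshape(allowUpSizing = true,
      kwargs = Map("data" -> Shape(256, 3, 224, 224)))
    // ... use bigger ...
    bigger.dispose()   // also frees the NDArrays reshape allocated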
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
index da58976..123eae9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
@@ -28,15 +28,17 @@ object Optimizer {
def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
new MXKVStoreUpdater with MXKVStoreCachedStates {
override def update(index: Int, grad: NDArray, weight: NDArray): Unit = {
- val state =
- if (states.contains(index)) {
- states.get(index).get
- } else {
- val newState = optimizer.createState(index, weight)
- states.put(index, newState)
- newState
- }
- optimizer.update(index, weight, grad, state)
+ ResourceScope.usingIfScopeExists(this.scope) {
+ val state =
+ if (states.contains(index)) {
+ states.get(index).get
+ } else {
+ val newState = optimizer.createState(index, weight)
+ states.put(index, newState)
+ newState
+ }
+ optimizer.update(index, weight, grad, state)
+ }
}
override def dispose(): Unit = {
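
Previously the updater created optimizer state (for Adam, the mean and variance NDArrays) in whatever scope happened to be active at update time, so cached state could either leak or be disposed out from under the cache. Wrapping the body in usingIfScopeExists(this.scope) pins state creation to the updater's own scope when one exists. Stripped of the caching, the shape of the fix is roughly:

    // Illustrative: state built by createState now attaches to the
    // updater's scope, not to the innermost scope at call time.
    ResourceScope.usingIfScopeExists(this.scope) {
      val state = optimizer.createState(index, weight)
      optimizer.update(index, weight, grad, state)
    }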
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
index bb363c0..b955c18 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -48,8 +48,10 @@ class ResourceScope extends AutoCloseable {
*/
override def close(): Unit = {
ResourceScope.removeFromThreadLocal(this)
- resourceQ.foreach(resource => if (resource != null) resource.dispose(false) )
- resourceQ.clear()
+ if (!ResourceScope.threadLocalScopes.get().contains(this)) {
+ resourceQ.foreach(resource => if (resource != null) resource.dispose(false))
+ resourceQ.clear()
+ }
}
/**
@@ -145,7 +147,7 @@ object ResourceScope {
null.asInstanceOf[A] // we'll throw in finally
} finally {
var toThrow: Throwable = retThrowable
- if (retThrowable eq null) curScope.close()
+ if (retThrowable eq null) curScope.close
else {
try {
curScope.close
@@ -160,6 +162,17 @@ object ResourceScope {
}
}
+ private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = {
+ if (scope == None) {
+ body
+ } else {
+ ResourceScope.addToThreadLocal(scope.get)
+ ResourceScope.using(scope.get){
+ body
+ }
+ }
+ }
+
// thread local Scopes
private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
override def initialValue(): ArrayBuffer[ResourceScope] =
@@ -179,7 +192,7 @@ object ResourceScope {
* @param r ResourceScope to remove
*/
private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
- threadLocalScopes.get() -= r
+ threadLocalScopes.get().remove(threadLocalScopes.get().lastIndexOf(r))
}
/**
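
Two subtleties in the ResourceScope changes. First, usingIfScopeExists pushes an already-open scope onto the thread-local stack a second time before delegating to using, so when using's internal close() runs, the new guard in close() finds the scope still on the stack and skips disposal; only the final close of the scope actually frees its resources. Second, removeFromThreadLocal now removes the most recent occurrence of the scope (ArrayBuffer's -= removes the first match), which keeps that double-push bookkeeping balanced. A compact sketch of the observable behavior, borrowing TestNativeResource from the test suite below:

    // Illustrative, mirroring the new unit test at the end of this commit.
    ResourceScope.using() {
      val a = new TestNativeResource()
      ResourceScope.using() {
        ResourceScope.usingIfScopeExists(a.scope) {
          val c = new TestNativeResource()  // registered in a's scope
        }
      }        // inner scope closes; c survives because it lives in a's scope
    }          // a's scope closes; a and c are disposed together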
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 29885fc..821e04f 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -803,18 +803,23 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
auxArgsHandle,
sharedHandle,
execHandle))
- val executor = new Executor(execHandle.value, this.clone())
- executor.argArrays = argsNDArray
- executor.gradArrays = argsGradNDArray
- executor.auxArrays = auxStatesNDArray
- executor._ctx = new Context(ctx.deviceType, ctx.deviceId)
- executor._gradsReq = gradsReq
- executor._group2ctx =
+
+ val executorGroup2ctx =
if (group2ctx == null) null
else group2ctx.map { case (key, value) =>
key -> new Context(value.deviceType, value.deviceId)
}
- executor
+
+ // If this is in a scope then we want to create the clone in the same scope
+ var newSymbol: Symbol = null
+ ResourceScope.usingIfScopeExists(this.scope) {
+ newSymbol = this.clone()
+ }
+
+ new Executor(execHandle.value, newSymbol, argsNDArray, argsGradNDArray,
+ auxStatesNDArray, new Context(ctx.deviceType, ctx.deviceId),
+ gradsReq, executorGroup2ctx)
+
}
/**
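
Symbol.bind clones the symbol so the executor can dispose its own copy independently, and the clone is now created inside the original symbol's scope (when one exists) so it cannot attach to some unrelated inner scope and be freed early. Passing the arrays and context through the new Executor constructor also replaces the old post-hoc field mutation. A hedged usage sketch, assuming the two-argument bind overload that takes a Map[String, NDArray] (sym and argsMap are illustrative):

    // Illustrative: everything bind creates, including the symbol clone,
    // is disposed when the surrounding scope closes.
    ResourceScope.using() {
      val exec = sym.bind(Context.cpu(), argsMap)
      exec.forward()
    }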
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
index df66ea7..74e63be 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
@@ -299,7 +299,6 @@ class DataParallelExecutorGroup private[module](
private var batchSize: Int = -1
private var slices: Array[(Int, Int)] = null
- private var _defaultExecs: Array[Executor] = null
private var execs: Array[Executor] = null
private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
@@ -373,7 +372,12 @@ class DataParallelExecutorGroup private[module](
val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
val inputShapes
= dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
- execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
+
+ ResourceScope.usingIfScopeExists(execs(i).scope) {
+ val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
+ execs(i).dispose()
+ execs(i) = tmpExec
+ }
}
} else {
execs = (0 until contexts.length).map(i =>
@@ -434,9 +438,6 @@ class DataParallelExecutorGroup private[module](
*/
def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
- if (this._defaultExecs == null) {
- this._defaultExecs = this.execs.map(x => x)
- }
this.bindExec(dataShapes, labelShapes, None, reshape = true)
}
}
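
This was the executor-reshape leak named in the commit message: _defaultExecs kept the original executors alive forever purely as reshape sources, so their native memory was never released. The replacement reshapes each executor from itself inside its own scope and disposes the predecessor. The same pattern, as a generic sketch (package-internal, since usingIfScopeExists is private[mxnet]; oldExec and newShapes are illustrative):

    // Illustrative: swap an executor for its reshaped successor without
    // leaking the predecessor's native memory.
    val newExec = ResourceScope.usingIfScopeExists(oldExec.scope) {
      val next = oldExec.reshape(allowUpSizing = true, kwargs = newShapes)
      oldExec.dispose()
      next
    }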
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala
index 24f3323..5a8b3cb 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.optimizer
import org.apache.mxnet.NDArrayConversions._
import org.apache.mxnet.util.SerializerUtils
-import org.apache.mxnet.{LRScheduler, NDArray, Optimizer}
+import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, ResourceScope}
/**
* Adam optimizer as described in [King2014]
@@ -57,63 +57,54 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float =
* The auxiliary state used in optimization.
*/
override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
- var lr =
- (if (lrScheduler != null) {
- val scheduledLr = lrScheduler(numUpdate)
- updateCount(index)
- scheduledLr
- } else {
- this.learningRate
- })
- lr = getLr(index, lr)
-
- val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]
-
- // increment time only when the first parameters is called
- timeFirstIndex match {
- case Some(idx) =>
- if (idx == index) time += 1
- case None =>
- timeFirstIndex = Option(index)
- time = 0 // all parameters share the same time
- }
-
- val t1: Int = time + 1
- val learningRate = (lr *
- math.sqrt(1.0 - math.pow(beta2, t1)) /
- (1.0 - math.pow(beta1, t1))).toFloat
- val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat
-
- var resdGrad = grad * rescaleGrad
- if (clipGradient != 0f) {
- val oldResdGrad = resdGrad
- resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
- oldResdGrad.dispose()
- }
-
- val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
- .disposeDepsExcept(mean, resdGrad)
- val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
- .disposeDepsExcept(variance, resdGrad)
+ ResourceScope.using() {
+ var lr =
+ (if (lrScheduler != null) {
+ val scheduledLr = lrScheduler(numUpdate)
+ updateCount(index)
+ scheduledLr
+ } else {
+ this.learningRate
+ })
+ lr = getLr(index, lr)
- val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))
- .disposeDepsExcept(meanT, varianceT)
+ val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]
- val wd = this.getWd(index, this.wd)
- if (wd > 0.0f) {
- val stepDelta = lr * wd * weight
- step += stepDelta
- stepDelta.dispose()
+ // increment time only when the first parameters is called
+ timeFirstIndex match {
+ case Some(idx) =>
+ if (idx == index) time += 1
+ case None =>
+ timeFirstIndex = Option(index)
+ time = 0 // all parameters share the same time
+ }
+
+ val t1: Int = time + 1
+ val learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) /
+ (1.0 - math.pow(beta1, t1))).toFloat
+ val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat
+
+ var resdGrad = grad * rescaleGrad
+ if (clipGradient != 0f) {
+ val oldResdGrad = resdGrad
+ resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
+ }
+
+ val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
+ val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
+ val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))
+
+ val wd = this.getWd(index, this.wd)
+ if (wd > 0.0f) {
+ val stepDelta = lr * wd * weight
+ step += stepDelta
+ }
+
+ weight -= step
+ mean.set(meanT)
+ variance.set(varianceT)
+ (mean, variance)
}
-
- weight -= step
- mean.set(meanT)
- variance.set(varianceT)
-
- meanT.dispose()
- varianceT.dispose()
- step.dispose()
- resdGrad.dispose()
}
// Create additional optimizer state: mean, variance
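
The Adam rewrite trades the manual bookkeeping (disposeDepsExcept plus the trailing dispose() calls) for a single scope: every NDArray temporary created inside ResourceScope.using() is disposed when the block exits, while resources returned from the block are moved to the enclosing scope instead, which is the behavior the ResourceScopeSuite below asserts. A reduced sketch of the idiom (grad, rescaleGrad, learningRate, and weight stand in for the values in Adam.update):

    // Illustrative: tmp dies when the scope closes; the NDArray returned
    // from the block is retained rather than disposed.
    val step = ResourceScope.using() {
      val tmp = grad * rescaleGrad   // temporary, disposed at scope exit
      tmp * learningRate             // returned, so moved to the outer scope
    }
    weight -= step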
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
index 41dfa7d..1916238 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
@@ -101,6 +101,39 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(a.isDisposed == true, "returned object should be disposed in the outer scope")
}
+ /**
+ * Tests passing a scope to using and creating new resources within.
+ */
+ test("test moving scope of native resource to scope of another") {
+ var a: TestNativeResource = null
+ var b: TestNativeResource = null
+ var c: TestNativeResource = null
+ var d: TestNativeResource = null
+
+ ResourceScope.using() {
+ a = new TestNativeResource()
+ ResourceScope.using() {
+ b = new TestNativeResource()
+ ResourceScope.usingIfScopeExists(a.scope) {
+ c = new TestNativeResource()
+ ResourceScope.using() {
+ d = new TestNativeResource()
+ assert(c.scope == a.scope)
+ }
+ assert(d.isDisposed == true)
+ }
+ assert(b.isDisposed == false)
+ assert(c.isDisposed == false)
+ }
+ assert(a.isDisposed == false)
+ assert(b.isDisposed == true)
+ assert(c.isDisposed == false)
+ }
+ assert(a.isDisposed == true)
+ assert(b.isDisposed == true)
+ assert(c.isDisposed == true)
+ }
+
test(testName = "test NativeResources in returned Lists are not disposed") {
var ndListRet: IndexedSeq[TestNativeResource] = null
ResourceScope.using() {