You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/15 19:36:35 UTC

[GitHub] yzhliu closed pull request #13281: [MXNET-1214] clean up the NDArray

yzhliu closed pull request #13281: [MXNET-1214] clean up the NDArray
URL: https://github.com/apache/incubator-mxnet/pull/13281
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index 446df257e20..6b4f4bdebda 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -50,7 +50,7 @@ object NDArray extends NDArrayBase {
   = org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
   def empty(ctx: Context, shape: Array[Int]): NDArray
   = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
-  def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  def empty(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray
   = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
 
   /**
@@ -65,7 +65,7 @@ object NDArray extends NDArrayBase {
   = org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
   def zeros(ctx: Context, shape: Array[Int]): NDArray
   = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
-  def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  def zeros(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray
   = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
 
   /**
@@ -78,7 +78,7 @@ object NDArray extends NDArrayBase {
   = org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
   def ones(ctx: Context, shape: Array[Int]): NDArray
   = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
-  def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  def ones(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray
   = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
 
   /**
@@ -193,37 +193,47 @@ object NDArray extends NDArrayBase {
   * NDArray is basic ndarray/Tensor like data structure in mxnet. <br />
   * <b>
   * NOTE: NDArray is stored in native memory. Use NDArray in a try-with-resources() construct
-  * or a [[ResourceScope]] in a try-with-resource to have them automatically disposed. You can
-  * explicitly control the lifetime of NDArray by calling dispose manually. Failure to do this
-  * will result in leaking native memory.
+  * or a [[org.apache.mxnet.ResourceScope]] in a try-with-resource to have them
+  * automatically disposed. You can explicitly control the lifetime of NDArray
+  * by calling dispose manually. Failure to do this will result in leaking native memory.
   * </b>
   */
-class NDArray private[mxnet] (val nd : org.apache.mxnet.NDArray ) {
+class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
 
-  def this(arr : Array[Float], shape : Shape, ctx : Context) = {
+  def this(arr: Array[Float], shape: Shape, ctx: Context) = {
     this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
   }
 
-  def this(arr : java.util.List[java.lang.Float], shape : Shape, ctx : Context) = {
+  def this(arr: java.util.List[java.lang.Float], shape: Shape, ctx: Context) = {
     this(NDArray.array(arr, shape, ctx))
   }
 
-  def serialize() : Array[Byte] = nd.serialize()
+  def serialize(): Array[Byte] = nd.serialize()
 
   /**
     * Release the native memory. <br />
     * The NDArrays it depends on will NOT be disposed. <br />
     * The object shall never be used after it is disposed.
     */
-  def dispose() : Unit = nd.dispose()
+  def dispose(): Unit = nd.dispose()
 
   /**
     * Dispose all NDArrays who help to construct this array. <br />
     * e.g. (a * b + c).disposeDeps() will dispose a, b, c (including their deps) and a * b
     * @return this array
     */
-  def disposeDeps() : NDArray = nd.disposeDepsExcept()
-  // def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept()
+  def disposeDeps(): NDArray = nd.disposeDepsExcept()
+
+  /**
+    * Dispose all NDArrays who help to construct this array, excepts those in the arguments. <br />
+    * e.g. (a * b + c).disposeDepsExcept(a, b)
+    * will dispose c and a * b.
+    * Note that a, b's dependencies will not be disposed either.
+    * @param arr the Array of NDArray not to dispose
+    * @return this array
+    */
+  def disposeDepsExcept(arr: Array[NDArray]): NDArray =
+    nd.disposeDepsExcept(arr.map(NDArray.toNDArray): _*)
 
   /**
     * Return a sliced NDArray that shares memory with current one.
@@ -234,36 +244,36 @@ class NDArray private[mxnet] (val nd : org.apache.mxnet.NDArray ) {
     *
     * @return a sliced NDArray that shares memory with current one.
     */
-  def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)
+  def slice(start: Int, stop: Int): NDArray = nd.slice(start, stop)
 
   /**
     * Return a sliced NDArray at the ith position of axis0
     * @param i
     * @return a sliced NDArray that shares memory with current one.
     */
-  def slice (i : Int) : NDArray = nd.slice(i)
+  def slice (i: Int): NDArray = nd.slice(i)
 
   /**
     * Return a sub NDArray that shares memory with current one.
     * the first axis will be rolled up, which causes its shape different from slice(i, i+1)
     * @param idx index of sub array.
     */
-  def at(idx : Int) : NDArray = nd.at(idx)
+  def at(idx: Int): NDArray = nd.at(idx)
 
-  def T : NDArray = nd.T
+  def T: NDArray = nd.T
 
   /**
     * Get data type of current NDArray.
     * @return class representing type of current ndarray
     */
-  def dtype : DType = nd.dtype
+  def dtype: DType = nd.dtype
 
   /**
     * Return a copied numpy array of current array with specified type.
     * @param dtype Desired type of result array.
     * @return A copy of array content.
     */
-  def asType(dtype : DType) : NDArray = nd.asType(dtype)
+  def asType(dtype: DType): NDArray = nd.asType(dtype)
 
   /**
     * Return a reshaped NDArray that shares memory with current one.
@@ -271,7 +281,7 @@ class NDArray private[mxnet] (val nd : org.apache.mxnet.NDArray ) {
     *
     * @return a reshaped NDArray that shares memory with current one.
     */
-  def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)
+  def reshape(dims: Array[Int]): NDArray = nd.reshape(dims)
 
   /**
     * Block until all pending writes operations on current NDArray are finished.
@@ -285,55 +295,55 @@ class NDArray private[mxnet] (val nd : org.apache.mxnet.NDArray ) {
     * Get context of current NDArray.
     * @return The context of current NDArray.
     */
-  def context : Context = nd.context
+  def context: Context = nd.context
 
   /**
     * Set the values of the NDArray
     * @param value Value to set
     * @return Current NDArray
     */
-  def set(value : Float) : NDArray = nd.set(value)
-  def set(other : NDArray) : NDArray = nd.set(other)
-  def set(other : Array[Float]) : NDArray = nd.set(other)
-
-  def add(other : NDArray) : NDArray = this.nd + other.nd
-  def add(other : Float) : NDArray = this.nd + other
-  def _add(other : NDArray) : NDArray = this.nd += other
-  def _add(other : Float) : NDArray = this.nd += other
-  def subtract(other : NDArray) : NDArray = this.nd - other
-  def subtract(other : Float) : NDArray = this.nd - other
-  def _subtract(other : NDArray) : NDArray = this.nd -= other
-  def _subtract(other : Float) : NDArray = this.nd -= other
-  def multiply(other : NDArray) : NDArray = this.nd * other
-  def multiply(other : Float) : NDArray = this.nd * other
-  def _multiply(other : NDArray) : NDArray = this.nd *= other
-  def _multiply(other : Float) : NDArray = this.nd *= other
-  def div(other : NDArray) : NDArray = this.nd / other
-  def div(other : Float) : NDArray = this.nd / other
-  def _div(other : NDArray) : NDArray = this.nd /= other
-  def _div(other : Float) : NDArray = this.nd /= other
-  def pow(other : NDArray) : NDArray = this.nd ** other
-  def pow(other : Float) : NDArray = this.nd ** other
-  def _pow(other : NDArray) : NDArray = this.nd **= other
-  def _pow(other : Float) : NDArray = this.nd **= other
-  def mod(other : NDArray) : NDArray = this.nd % other
-  def mod(other : Float) : NDArray = this.nd % other
-  def _mod(other : NDArray) : NDArray = this.nd %= other
-  def _mod(other : Float) : NDArray = this.nd %= other
-  def greater(other : NDArray) : NDArray = this.nd > other
-  def greater(other : Float) : NDArray = this.nd > other
-  def greaterEqual(other : NDArray) : NDArray = this.nd >= other
-  def greaterEqual(other : Float) : NDArray = this.nd >= other
-  def lesser(other : NDArray) : NDArray = this.nd < other
-  def lesser(other : Float) : NDArray = this.nd < other
-  def lesserEqual(other : NDArray) : NDArray = this.nd <= other
-  def lesserEqual(other : Float) : NDArray = this.nd <= other
+  def set(value: Float): NDArray = nd.set(value)
+  def set(other: NDArray): NDArray = nd.set(other)
+  def set(other: Array[Float]): NDArray = nd.set(other)
+
+  def add(other: NDArray): NDArray = this.nd + other.nd
+  def add(other: Float): NDArray = this.nd + other
+  def addInplace(other: NDArray): NDArray = this.nd += other
+  def addInplace(other: Float): NDArray = this.nd += other
+  def subtract(other: NDArray): NDArray = this.nd - other
+  def subtract(other: Float): NDArray = this.nd - other
+  def subtractInplace(other: NDArray): NDArray = this.nd -= other
+  def subtractInplace(other: Float): NDArray = this.nd -= other
+  def multiply(other: NDArray): NDArray = this.nd * other
+  def multiply(other: Float): NDArray = this.nd * other
+  def multiplyInplace(other: NDArray): NDArray = this.nd *= other
+  def multiplyInplace(other: Float): NDArray = this.nd *= other
+  def div(other: NDArray): NDArray = this.nd / other
+  def div(other: Float): NDArray = this.nd / other
+  def divInplace(other: NDArray): NDArray = this.nd /= other
+  def divInplace(other: Float): NDArray = this.nd /= other
+  def pow(other: NDArray): NDArray = this.nd ** other
+  def pow(other: Float): NDArray = this.nd ** other
+  def powInplace(other: NDArray): NDArray = this.nd **= other
+  def powInplace(other: Float): NDArray = this.nd **= other
+  def mod(other: NDArray): NDArray = this.nd % other
+  def mod(other: Float): NDArray = this.nd % other
+  def modInplace(other: NDArray): NDArray = this.nd %= other
+  def modInplace(other: Float): NDArray = this.nd %= other
+  def greater(other: NDArray): NDArray = this.nd > other
+  def greater(other: Float): NDArray = this.nd > other
+  def greaterEqual(other: NDArray): NDArray = this.nd >= other
+  def greaterEqual(other: Float): NDArray = this.nd >= other
+  def lesser(other: NDArray): NDArray = this.nd < other
+  def lesser(other: Float): NDArray = this.nd < other
+  def lesserEqual(other: NDArray): NDArray = this.nd <= other
+  def lesserEqual(other: Float): NDArray = this.nd <= other
 
   /**
     * Return a copied flat java array of current array (row-major).
     * @return  A copy of array content.
     */
-  def toArray : Array[Float] = nd.toArray
+  def toArray: Array[Float] = nd.toArray
 
   /**
     * Return a CPU scalar(float) of current ndarray.
@@ -341,7 +351,7 @@ class NDArray private[mxnet] (val nd : org.apache.mxnet.NDArray ) {
     *
     * @return The scalar representation of the ndarray.
     */
-  def toScalar : Float = nd.toScalar
+  def toScalar: Float = nd.toScalar
 
   /**
     * Copy the content of current array to other.
@@ -349,7 +359,7 @@ class NDArray private[mxnet] (val nd : org.apache.mxnet.NDArray ) {
     * @param other Target NDArray or context we want to copy data to.
     * @return The copy target NDArray
     */
-  def copyTo(other : NDArray) : NDArray = nd.copyTo(other)
+  def copyTo(other: NDArray): NDArray = nd.copyTo(other)
 
   /**
     * Copy the content of current array to a new NDArray in the context.
@@ -357,22 +367,22 @@ class NDArray private[mxnet] (val nd : org.apache.mxnet.NDArray ) {
     * @param ctx Target context we want to copy data to.
     * @return The copy target NDArray
     */
-  def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)
+  def copyTo(ctx: Context): NDArray = nd.copyTo(ctx)
 
   /**
     * Clone the current array
     * @return the copied NDArray in the same context
     */
-  def copy() : NDArray = copyTo(this.context)
+  def copy(): NDArray = copyTo(this.context)
 
   /**
     * Get shape of current NDArray.
     * @return an array representing shape of current ndarray
     */
-  def shape : Shape = nd.shape
+  def shape: Shape = nd.shape
 
 
-  def size : Int = shape.product
+  def size: Int = shape.product
 
   /**
     * Return an `NDArray` that lives in the target context. If the array


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services