You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/01/25 20:01:22 UTC

[incubator-mxnet] branch master updated: [MXNET-1000] get Ndarray real value and form it from a NDArray (#12690)

This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0f334ae  [MXNET-1000] get Ndarray real value and form it from a NDArray (#12690)
0f334ae is described below

commit 0f334aecf569c2f0ed5a279798e0ca58c4d143dc
Author: Lanking <la...@live.com>
AuthorDate: Fri Jan 25 12:01:02 2019 -0800

    [MXNET-1000] get Ndarray real value and form it from a NDArray (#12690)
    
    * add visualize
    
    * adding Any type input to form NDArray
    
    * fix bug and add tests
    
    * add a toString method
    
    * add Visualize Util and migrate visualize structure to there
    
    * update with tests
    
    * refactor code
    
    * fix the minor issue
    
    * add multiple types support
    
    * add changes on names and tests
    
    * make code elegant and improve readability
---
 .../scala/org/apache/mxnet/MX_PRIMITIVES.scala     |   6 ++
 .../src/main/scala/org/apache/mxnet/NDArray.scala  | 112 ++++++++++++++++++++-
 .../test/scala/org/apache/mxnet/NDArraySuite.scala |  82 +++++++++++++++
 3 files changed, 199 insertions(+), 1 deletion(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
index cb97885..3a51222 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
@@ -82,4 +82,10 @@ object MX_PRIMITIVES {
 
   implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data
 
+  /** Tells whether `num` is a Float or a Double, the only types this check accepts. */
+  def isValidMxPrimitiveType(num : Any) : Boolean = num match {
+    // a type test is enough here; the matched value itself is never used
+    case _: Float | _: Double => true
+    case _ => false
+  }
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 163ed26..5c345f2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -28,6 +28,7 @@ import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.language.implicitConversions
 import scala.ref.WeakReference
+import scala.util.Try
 
 /**
   * NDArray Object extends from NDArrayBase for abstract function signatures
@@ -510,6 +511,61 @@ object NDArray extends NDArrayBase {
   }
 
   /**
+    * Builds an NDArray whose shape mirrors the nesting of the given source Array
+    * @param sourceArr nested Array, i.e. Array[Array...Array[MX_PRIMITIVE_TYPE]...]
+    * @param ctx context handed on unchanged to `array`; defaults to null
+    * @return an NDArray with the same shape as the input
+    * @throws IllegalArgumentException if the element type is not supported
+    */
+  def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = {
+    val dims = shapeGetter(sourceArr)
+    val flat = new Array[Any](dims.product)
+    flattenArray(sourceArr, flat, 0, flat.length - 1)
+    // the runtime type of the first flattened value selects the element type for all values
+    flat(0) match {
+      case _: Float => array(flat.map(_.asInstanceOf[Float]), Shape(dims), ctx)
+      case _: Double => array(flat.map(_.asInstanceOf[Double]), Shape(dims), ctx)
+      case other => throw new IllegalArgumentException(
+        s"Unsupported type ${other.getClass}, please check MX_PRIMITIVES for valid types")
+    }
+  }
+
+  // Recursively derives the shape (outermost dimension first) of a nested Array whose leaves
+  // are Float/Double. Throws IllegalArgumentException for empty, ragged or non-Array input.
+  private def shapeGetter(sourceArr : Any) : ArrayBuffer[Int] = sourceArr match {
+    // arr(0) used to be read unguarded, so empty input failed with an index error instead
+    case arr: Array[_] if arr.isEmpty =>
+      throw new IllegalArgumentException("Cannot infer the shape of an empty Array")
+    // e.g : Array[Double] the inner layer
+    case arr: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arr(0)) =>
+      ArrayBuffer[Int](arr.length)
+    // e.g : Array[Array...[]]
+    case arr: Array[_] =>
+      val inner = shapeGetter(arr(0))
+      // every remaining sub-array must repeat the shape of the first one
+      require(arr.iterator.drop(1).forall(sub => shapeGetter(sub) == inner),
+        s"Ragged input: every sub-array must have shape $inner")
+      // prepend this level's length and hand the same buffer back
+      arr.length +=: inner
+    case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
+  }
+
+  // Copies the leaves of nested `sourceArr` into arr(start) .. arr(end), `end` inclusive.
+  private def flattenArray(sourceArr : Any, arr : Array[Any],
+                            start : Int, end : Int) : Unit = sourceArr match {
+    case arrValid: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arrValid(0)) =>
+      for (i <- arrValid.indices) arr(start + i) = arrValid(i)
+    case arrAny: Array[_] =>
+      val fragment = (end - start + 1) / arrAny.length  // leaves owned by each child
+      // `end` is inclusive, hence the -1: an exclusive bound inflated `fragment` one level
+      // down and wrote out of bounds for shapes such as (1, 1, 2, 1)
+      for (i <- arrAny.indices) {
+        flattenArray(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment - 1)
+      }
+    case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
+  }
+
+  /**
    * Returns evenly spaced values within a given interval.
    * Values are generated within the half-open interval [`start`, `stop`). In other
    * words, the interval includes `start` but excludes `stop`.
@@ -667,7 +723,6 @@ object NDArray extends NDArrayBase {
     genericNDArrayFunctionInvoke("_crop_assign", args, kwargs)
   }
 
-  // TODO: imdecode
 }
 
 /**
@@ -694,6 +749,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
   // we use weak reference to prevent gc blocking
   private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
 
+  private val lengthProperty = "mxnet.setNDArrayPrintLength"  // printLength key (default 1000)
+  private val layerProperty = "mxnet.setNDArrayPrintLayerLength"  // layerLength key (default 10)
+  private lazy val printLength = Try(System.getProperty(lengthProperty).toInt).getOrElse(1000)
+  private lazy val layerLength = Try(System.getProperty(layerProperty).toInt).getOrElse(10)
+
   def serialize(): Array[Byte] = {
     val buf = ArrayBuffer.empty[Byte]
     checkCall(_LIB.mxNDArraySaveRawBytes(handle, buf))
@@ -764,6 +824,56 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
   }
 
   /**
+    * Renders the array contents followed by a `<NDArray shape context dtype>` summary line
+    * @return the printable representation of this NDArray
+    */
+  override def toString: String = {
+    val body = buildStringHelper(this, shape.length)
+    val summary = s"<NDArray $shape $context $dtype>"
+    body + "\n" + summary
+  }
+
+  /**
+    * Helper function to create formatted NDArray output
+    * The NDArray will be represented in a reduced version if too large
+    * @param nd NDArray as the input
+    * @param totalSpace totalSpace of the lowest dimension
+    * @return String format of NDArray
+    */
+  private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = {
+    var result = ""
+    val THRESHOLD = layerLength        // longest NDArray[NDArray[...]] to show in full
+    val ARRAYTHRESHOLD = printLength   // longest array to show in full
+    val shape = nd.shape
+    val space = totalSpace - shape.length
+    if (shape.length != 1) {
+      val (length, postfix) =
+        if (shape(0) > THRESHOLD) {
+          // reduced NDArray
+          (10, s"\n${" " * (space + 1)}... with length ${shape(0)}\n")
+        } else {
+          (shape(0), "")
+        }
+      for (num <- 0 until length) {
+        val output = buildStringHelper(nd.at(num), totalSpace)
+        result += s"$output\n"
+      }
+      result = s"${" " * space}[\n$result${" " * space}$postfix${" " * space}]"
+    } else {
+      if (shape(0) > ARRAYTHRESHOLD) {
+        // reduced Array
+        val front = nd.slice(0, 10)
+        val back = nd.slice(shape(0) - 10, shape(0) - 1)
+        result = s"""${" " * space}[${front.toArray.mkString(",")}
+             | ... ${back.toArray.mkString(",")}]""".stripMargin
+      } else {
+        result = s"${" " * space}[${nd.toArray.mkString(",")}]"
+      }
+    }
+    result
+  }
+
+  /**
    * Return a sliced NDArray that shares memory with current one.
    * NDArray only support continuous slicing on axis 0
    *
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index bc7a0a0..054300e 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -22,10 +22,14 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.mxnet.NDArrayConversions._
 import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.slf4j.LoggerFactory
+import scala.collection.mutable.ArrayBuffer
 
 class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   private val sequence: AtomicInteger = new AtomicInteger(0)
 
+  private val logger = LoggerFactory.getLogger(classOf[NDArraySuite])
+
   test("to java array") {
     val ndarray = NDArray.zeros(2, 2)
     assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
@@ -85,6 +89,84 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     assert(ndarray.toArray === Array(1f, 2f, 3f, 4f))
   }
 
+  test("create NDArray based on Java Matrix") {
+    // Builds a 2 x 1 x 100 x 4 nested Array whose leaves have the same type as `num`
+    def arrayGen(num : Any) : Array[Any] = {
+      val array = num match {
+        case _: Float =>
+          (for (_ <- 0 until 100) yield Array(1.0f, 1.0f, 1.0f, 1.0f)).toArray
+        case _: Double =>
+          (for (_ <- 0 until 100) yield Array(1.0d, 1.0d, 1.0d, 1.0d)).toArray
+        case _ => throw new IllegalArgumentException(s"Unsupported Type ${num.getClass}")
+      }
+      Array(Array(array), Array(array))
+    }
+
+    // assert (as in the other tests of this suite) reports a test failure with the
+    // offending expression, whereas require only raises IllegalArgumentException
+    // nested Float input
+    val floatData = 1.0f
+    val fromNestedFloat = NDArray.toNDArray(arrayGen(floatData))
+    assert(fromNestedFloat.shape == Shape(2, 1, 100, 4))
+    // flat Float input
+    val arr2 = Array(1.0f, 1.0f, 1.0f, 1.0f)
+    val fromFlatFloat = NDArray.toNDArray(arr2)
+    assert(fromFlatFloat.shape == Shape(4))
+    // nested Double input keeps double precision
+    val doubleData = 1.0d
+    val fromNestedDouble = NDArray.toNDArray(arrayGen(doubleData))
+    assert(fromNestedDouble.shape == Shape(2, 1, 100, 4))
+    assert(fromNestedDouble.dtype == DType.Float64)
+  }
+
+  test("test Visualize") {
+    // whitespace is stripped on both sides, so only brackets, values and the summary count
+    val large = NDArray.ones(Shape(1, 2, 1000, 1))
+    val expectedLarge : String = """
+        |[
+        | [
+        |  [
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |
+        |   ... with length 1000
+        |  ]
+        |  [
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |
+        |   ... with length 1000
+        |  ]
+        |  ]
+        |]
+        |<NDArray (1,2,1000,1) cpu(0) float32>""".stripMargin
+    assert(large.toString.split("\\s+").mkString == expectedLarge.split("\\s+").mkString)
+    // a small NDArray is printed in full
+    val small = NDArray.ones(Shape(1, 4))
+    val expectedSmall = """
+        |[
+        | [1.0,1.0,1.0,1.0]
+        |]
+        |<NDArray (1,4) cpu(0) float32>""".stripMargin
+    assert(small.toString.split("\\s+").mkString == expectedSmall.split("\\s+").mkString)
+  }
+
   test("plus") {
     var ndzeros = NDArray.zeros(2, 1)
     var ndones = ndzeros + 1f