You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/01/25 20:01:22 UTC
[incubator-mxnet] branch master updated: [MXNET-1000] get Ndarray
real value and form it from a NDArray (#12690)
This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0f334ae [MXNET-1000] get Ndarray real value and form it from a NDArray (#12690)
0f334ae is described below
commit 0f334aecf569c2f0ed5a279798e0ca58c4d143dc
Author: Lanking <la...@live.com>
AuthorDate: Fri Jan 25 12:01:02 2019 -0800
[MXNET-1000] get Ndarray real value and form it from a NDArray (#12690)
* add visualize
* adding Any type input to form NDArray
* fix bug and add tests
* add a toString method
* add Visualize Util and migrate visualize structure to there
* update with tests
* refactor code
* fix the minor issue
* add multiple types support
* add changes on names and tests
* make code elegant and improve readability
---
.../scala/org/apache/mxnet/MX_PRIMITIVES.scala | 6 ++
.../src/main/scala/org/apache/mxnet/NDArray.scala | 112 ++++++++++++++++++++-
.../test/scala/org/apache/mxnet/NDArraySuite.scala | 82 +++++++++++++++
3 files changed, 199 insertions(+), 1 deletion(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
index cb97885..3a51222 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
@@ -82,4 +82,10 @@ object MX_PRIMITIVES {
implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data
+  /**
+   * Check whether a value is one of the element types usable for building an NDArray
+   * from nested arrays.
+   * @param num value to inspect (a boxed Float/Double matches its primitive type)
+   * @return true for Float and Double values, false for anything else (including null)
+   */
+  def isValidMxPrimitiveType(num : Any) : Boolean = num match {
+    case _: Float | _: Double => true
+    case _ => false
+  }
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 163ed26..5c345f2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -28,6 +28,7 @@ import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.language.implicitConversions
import scala.ref.WeakReference
+import scala.util.Try
/**
* NDArray Object extends from NDArrayBase for abstract function signatures
@@ -510,6 +511,61 @@ object NDArray extends NDArrayBase {
}
/**
+   * Create a new NDArray based on the structure of source Array
+   * @param sourceArr Array[Array...Array[MX_PRIMITIVE_TYPE]...]; every nesting level must be
+   *                  rectangular and all leaf values must share a single type
+   * @param ctx context to create the NDArray on; passed through unchanged to `array`
+   * @return an NDArray with the same shape of the input
+   * @throws IllegalArgumentException if the data type is not valid, the leaf values mix
+   *                                  Float and Double, or the source array is empty
+   */
+  def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = {
+    val shape = shapeGetter(sourceArr)
+    val container = new Array[Any](shape.product)
+    flattenArray(sourceArr, container, 0, container.length - 1)
+    // The first value selects the dtype. Every other value must have the same type:
+    // narrowing a mixed container would otherwise fail with a ClassCastException
+    // instead of the IllegalArgumentException documented above.
+    container.headOption match {
+      case Some(_: Float) =>
+        val floats = container.collect { case f: Float => f }
+        require(floats.length == container.length,
+          "All elements of the source array must share one type, found a mix with Float")
+        array(floats, Shape(shape), ctx)
+      case Some(_: Double) =>
+        val doubles = container.collect { case d: Double => d }
+        require(doubles.length == container.length,
+          "All elements of the source array must share one type, found a mix with Double")
+        array(doubles, Shape(shape), ctx)
+      case Some(other) => throw new IllegalArgumentException(
+        s"Unsupported type ${other.getClass}, please check MX_PRIMITIVES for valid types")
+      case None => throw new IllegalArgumentException(
+        "Cannot create an NDArray from an empty source array")
+    }
+  }
+
+  /**
+   * Infer the shape of a (possibly nested) source array, outermost dimension first.
+   * @param sourceArr an Array of MX primitives, or an Array of such arrays nested to any depth
+   * @return the length of each dimension, outermost first
+   * @throws IllegalArgumentException if any level is empty, sibling sub-arrays have different
+   *                                  shapes, or a non-array value appears where an array
+   *                                  is expected
+   */
+  private def shapeGetter(sourceArr : Any) : ArrayBuffer[Int] = {
+    sourceArr match {
+      // e.g : Array[Double] the inner layer
+      case arr: Array[_] if arr.nonEmpty && MX_PRIMITIVES.isValidMxPrimitiveType(arr(0)) =>
+        ArrayBuffer[Int](arr.length)
+      // e.g : Array[Array...[]]
+      case arr: Array[_] if arr.nonEmpty =>
+        // Compute the first child's shape once and require every sibling to match it,
+        // so ragged input is rejected instead of being flattened into the wrong slots.
+        val innerShape = shapeGetter(arr(0))
+        for (idx <- 1 until arr.length) {
+          require(shapeGetter(arr(idx)) == innerShape,
+            s"All sub-arrays must have the same shape: expected " +
+              s"${innerShape.mkString("(", ",", ")")} at index $idx")
+        }
+        innerShape.insert(0, arr.length)
+        innerShape
+      case _: Array[_] =>
+        // Indexing the first element of an empty array used to throw
+        // ArrayIndexOutOfBoundsException here.
+        throw new IllegalArgumentException("Cannot infer the shape of an empty array")
+      case null => throw new IllegalArgumentException("Wrong type passed: null")
+      case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
+    }
+  }
+
+  /**
+   * Copy the leaf values of a nested source array into `arr` in row-major order.
+   * @param sourceArr the (possibly nested) array to flatten; its shape is expected to have
+   *                  been validated by shapeGetter already
+   * @param arr destination buffer
+   * @param start first index of `arr` owned by `sourceArr` (inclusive)
+   * @param end last index of `arr` owned by `sourceArr` (inclusive)
+   * @throws IllegalArgumentException if a non-array value appears where an array is expected
+   */
+  private def flattenArray(sourceArr : Any, arr : Array[Any],
+    start : Int, end : Int) : Unit = {
+    sourceArr match {
+      // Leaf level. An empty array has no values, so copying nothing is correct.
+      case arrValid: Array[_] if arrValid.isEmpty ||
+        MX_PRIMITIVES.isValidMxPrimitiveType(arrValid(0)) =>
+        for (i <- arrValid.indices) arr(start + i) = arrValid(i)
+      case arrAny: Array[_] =>
+        // Each child owns an equal, contiguous sub-range of [start, end].
+        val fragment = (end - start + 1) / arrAny.length
+        for (i <- arrAny.indices) {
+          val childStart = start + i * fragment
+          // `end` is inclusive, so a child's last index is one before the next child's start.
+          flattenArray(arrAny(i), arr, childStart, childStart + fragment - 1)
+        }
+      case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
+    }
+  }
+
+ /**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
* words, the interval includes `start` but excludes `stop`.
@@ -667,7 +723,6 @@ object NDArray extends NDArrayBase {
genericNDArrayFunctionInvoke("_crop_assign", args, kwargs)
}
- // TODO: imdecode
}
/**
@@ -694,6 +749,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
+  // System properties that tune how much of an NDArray toString prints in full.
+  // Unset or non-integer values fall back to the defaults passed below.
+  private val lengthProperty = "mxnet.setNDArrayPrintLength"
+  private val layerProperty = "mxnet.setNDArrayPrintLayerLength"
+  private lazy val printLength = intPropertyOrElse(lengthProperty, 1000)
+  private lazy val layerLength = intPropertyOrElse(layerProperty, 10)
+
+  /** Read `key` as an Int system property, or return `fallback` if it cannot be read/parsed. */
+  private def intPropertyOrElse(key: String, fallback: Int): Int =
+    Try(System.getProperty(key).toInt).getOrElse(fallback)
+
def serialize(): Array[Byte] = {
val buf = ArrayBuffer.empty[Byte]
checkCall(_LIB.mxNDArraySaveRawBytes(handle, buf))
@@ -764,6 +824,56 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
}
/**
+   * Render this NDArray as text: a nested-bracket view of its values (abbreviated when
+   * large) followed by a summary line with its shape, context and dtype.
+   * @return the formatted representation
+   */
+  override def toString: String = {
+    val values = buildStringHelper(this, shape.length)
+    val summary = "<NDArray " + shape + " " + context + " " + dtype + ">"
+    values + "\n" + summary
+  }
+
+  /**
+   * Helper function to create formatted NDArray output.
+   * The NDArray will be represented in a reduced version if too large: an axis longer than
+   * `layerLength` shows only its leading sub-arrays, and a 1-D array longer than
+   * `printLength` shows only its leading and trailing values.
+   * Sub-arrays created for printing are disposed before this method returns.
+   * @param nd NDArray as the input
+   * @param totalSpace number of dimensions of the outermost NDArray; a sub-array is indented
+   *                   by `totalSpace - nd.shape.length` spaces
+   * @return String format of NDArray
+   */
+  private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = {
+    val layerThreshold = layerLength // longest NDArray[NDArray[...]] to show in full
+    val arrayThreshold = printLength // longest array to show in full
+    val shape = nd.shape
+    val indent = " " * (totalSpace - shape.length)
+    if (shape.length != 1) {
+      val (length, postfix) =
+        if (shape(0) > layerThreshold) {
+          // Reduced NDArray. Never show more rows than the threshold: with a threshold
+          // below 10, a fixed count of 10 would make nd.at(num) run past axis 0.
+          (math.min(10, math.max(0, layerThreshold)),
+            s"\n$indent ... with length ${shape(0)}\n")
+        } else {
+          (shape(0), "")
+        }
+      val rows = new StringBuilder
+      for (num <- 0 until length) {
+        // at() allocates a new NDArray handle, so free it as soon as it has been printed.
+        val row = nd.at(num)
+        try rows.append(buildStringHelper(row, totalSpace)).append('\n')
+        finally row.dispose()
+      }
+      s"$indent[\n$rows$indent$postfix$indent]"
+    } else if (shape(0) > arrayThreshold) {
+      // Reduced Array: show `edge` values from each end. Capping at half the threshold
+      // keeps both slices inside the array and stops them from overlapping.
+      val edge = math.min(10, math.max(1, arrayThreshold / 2))
+      val front = nd.slice(0, edge)
+      // slice's stop index is exclusive; stopping at shape(0) keeps the last element.
+      val back = nd.slice(shape(0) - edge, shape(0))
+      try {
+        s"""$indent[${front.toArray.mkString(",")}
+           | ... ${back.toArray.mkString(",")}]""".stripMargin
+      } finally {
+        front.dispose()
+        back.dispose()
+      }
+    } else {
+      s"$indent[${nd.toArray.mkString(",")}]"
+    }
+  }
+
+ /**
* Return a sliced NDArray that shares memory with current one.
* NDArray only support continuous slicing on axis 0
*
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index bc7a0a0..054300e 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -22,10 +22,14 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.mxnet.NDArrayConversions._
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.slf4j.LoggerFactory
+import scala.collection.mutable.ArrayBuffer
class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
private val sequence: AtomicInteger = new AtomicInteger(0)
+  // Suite logger, named after NDArraySuite's fully-qualified class name.
+  private val logger = LoggerFactory.getLogger(classOf[NDArraySuite].getName)
+
test("to java array") {
val ndarray = NDArray.zeros(2, 2)
assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
@@ -85,6 +89,84 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(ndarray.toArray === Array(1f, 2f, 3f, 4f))
}
+  test("create NDArray based on Java Matrix") {
+    // Builds a (2, 1, 100, 4) nested array whose leaf values all have num's type.
+    def arrayGen(num : Any) : Array[Any] = {
+      val array = num match {
+        case _: Float =>
+          (for (_ <- 0 until 100) yield Array(1.0f, 1.0f, 1.0f, 1.0f)).toArray
+        case _: Double =>
+          (for (_ <- 0 until 100) yield Array(1.0d, 1.0d, 1.0d, 1.0d)).toArray
+        case _ => throw new IllegalArgumentException(s"Unsupported Type ${num.getClass}")
+      }
+      Array(Array(array), Array(array))
+    }
+    // assert (not require) so a mismatch is reported as a test failure showing the
+    // compared values, rather than as an unexpected IllegalArgumentException.
+    val floatNd = NDArray.toNDArray(arrayGen(1.0f))
+    assert(floatNd.shape == Shape(2, 1, 100, 4))
+    assert(floatNd.dtype == DType.Float32)
+    assert(floatNd.toArray.forall(_ == 1.0f))
+
+    val vectorNd = NDArray.toNDArray(Array(1.0f, 2.0f, 3.0f, 4.0f))
+    assert(vectorNd.shape == Shape(4))
+    assert(vectorNd.toArray === Array(1.0f, 2.0f, 3.0f, 4.0f))
+
+    val doubleNd = NDArray.toNDArray(arrayGen(1.0d))
+    assert(doubleNd.shape == Shape(2, 1, 100, 4))
+    assert(doubleNd.dtype == DType.Float64)
+  }
+
+  test("test Visualize") {
+    // Indentation in toString is cosmetic, so compare with all whitespace removed.
+    def squash(s: String): String = s.split("\\s+").mkString
+
+    val nd = NDArray.ones(Shape(1, 2, 1000, 1))
+    val expected : String =
+      """
+        |[
+        | [
+        |  [
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |
+        |   ... with length 1000
+        |  ]
+        |  [
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |   [1.0]
+        |
+        |   ... with length 1000
+        |  ]
+        | ]
+        |]
+        |<NDArray (1,2,1000,1) cpu(0) float32>""".stripMargin
+    assert(squash(nd.toString) == squash(expected))
+
+    val small = NDArray.ones(Shape(1, 4))
+    val expectedSmall : String =
+      """
+        |[
+        | [1.0,1.0,1.0,1.0]
+        |]
+        |<NDArray (1,4) cpu(0) float32>""".stripMargin
+    assert(squash(small.toString) == squash(expectedSmall))
+  }
+
test("plus") {
var ndzeros = NDArray.zeros(2, 1)
var ndones = ndzeros + 1f