You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/02/01 01:38:45 UTC
[incubator-mxnet] branch master updated: Now passing DType of Label downstream to Label's DataDesc object (#14038)
This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f95e794 Now passing DType of Label downstream to Label's DataDesc object (#14038)
f95e794 is described below
commit f95e7949dcd96ca2a5a140dbcff16dd344b45d19
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Thu Jan 31 17:38:26 2019 -0800
Now passing DType of Label downstream to Label's DataDesc object (#14038)
---
.../core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala | 6 ++++--
scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala | 3 ++-
2 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index e690abb..b205bbe 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -63,7 +63,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
dataName: String = "data", labelName: String = "label") {
this(IO.initDataDesc(data, allowEmpty = false, dataName,
if (data == null || data.isEmpty) MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
- IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
+ IO.initDataDesc(label, allowEmpty = true, labelName,
+ if (label == null || label.isEmpty) MX_REAL_TYPE else label(0).dtype, Layout.UNDEFINED),
dataBatchSize, shuffle, lastBatchHandle)
}
@@ -175,7 +176,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
private def _padData(ndArray: NDArray): NDArray = {
val padNum = cursor + dataBatchSize - numData
val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1, ndArray.shape.size)
- val newArray = NDArray.zeros(shape)
+ // The new NDArray has to be created such that it inherits dtype from the passed in array
+ val newArray = NDArray.zeros(shape, dtype = ndArray.dtype)
NDArrayCollector.auto().withScope {
val batch = ndArray.slice(cursor, numData)
val padding = ndArray.slice(0, padNum)
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index d3969b0..698a2b5 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -237,7 +237,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
val shape0 = Shape(Array(1000, 2, 2))
val data = IndexedSeq(NDArray.ones(shape0), NDArray.zeros(shape0))
val shape1 = Shape(Array(1000, 1))
- val label = IndexedSeq(NDArray.ones(shape1))
+ val label = IndexedSeq(NDArray.ones(shape1, dtype = DType.Int32))
val batchData0 = NDArray.ones(Shape(Array(128, 2, 2)))
val batchData1 = NDArray.zeros(Shape(Array(128, 2, 2)))
val batchLabel = NDArray.ones(Shape(Array(128, 1)))
@@ -254,6 +254,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(tBatch.data(0).toArray === batchData0.toArray)
assert(tBatch.data(1).toArray === batchData1.toArray)
assert(tBatch.label(0).toArray === batchLabel.toArray)
+ assert(tBatch.label(0).dtype == DType.Int32)
}
assert(batchCount === nBatch0)