You are viewing a plain-text version of this content; the link to the canonical version is not preserved in this rendering.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/02/01 01:38:45 UTC

[incubator-mxnet] branch master updated: Now passing DType of Label downstream to Label's DataDesc object (#14038)

This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f95e794  Now passing DType of Label downstream to Label's DataDesc object (#14038)
f95e794 is described below

commit f95e7949dcd96ca2a5a140dbcff16dd344b45d19
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Thu Jan 31 17:38:26 2019 -0800

    Now passing DType of Label downstream to Label's DataDesc object (#14038)
---
 .../core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala       | 6 ++++--
 scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala    | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index e690abb..b205bbe 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -63,7 +63,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
            dataName: String = "data", labelName: String = "label") {
     this(IO.initDataDesc(data, allowEmpty = false, dataName,
       if (data == null || data.isEmpty)  MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
-      IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
+      IO.initDataDesc(label, allowEmpty = true, labelName,
+        if (label == null || label.isEmpty)  MX_REAL_TYPE else label(0).dtype, Layout.UNDEFINED),
       dataBatchSize, shuffle, lastBatchHandle)
   }
 
@@ -175,7 +176,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
   private def _padData(ndArray: NDArray): NDArray = {
     val padNum = cursor + dataBatchSize - numData
     val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1, ndArray.shape.size)
-    val newArray = NDArray.zeros(shape)
+    // The new NDArray  has to be created such that it inherits dtype from the passed in array
+    val newArray = NDArray.zeros(shape, dtype = ndArray.dtype)
     NDArrayCollector.auto().withScope {
       val batch = ndArray.slice(cursor, numData)
       val padding = ndArray.slice(0, padNum)
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index d3969b0..698a2b5 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -237,7 +237,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     val shape0 = Shape(Array(1000, 2, 2))
     val data = IndexedSeq(NDArray.ones(shape0), NDArray.zeros(shape0))
     val shape1 = Shape(Array(1000, 1))
-    val label = IndexedSeq(NDArray.ones(shape1))
+    val label = IndexedSeq(NDArray.ones(shape1, dtype = DType.Int32))
     val batchData0 = NDArray.ones(Shape(Array(128, 2, 2)))
     val batchData1 = NDArray.zeros(Shape(Array(128, 2, 2)))
     val batchLabel = NDArray.ones(Shape(Array(128, 1)))
@@ -254,6 +254,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
       assert(tBatch.data(0).toArray === batchData0.toArray)
       assert(tBatch.data(1).toArray === batchData1.toArray)
       assert(tBatch.label(0).toArray === batchLabel.toArray)
+      assert(tBatch.label(0).dtype == DType.Int32)
     }
 
     assert(batchCount === nBatch0)