You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/06/15 23:50:11 UTC

[incubator-mxnet] branch master updated: [MXNET-539] Allow Scala users to specify data/label names for NDArrayIter (#11256)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 02e8a71  [MXNET-539] Allow Scala users to specify data/label names for NDArrayIter (#11256)
02e8a71 is described below

commit 02e8a7199fad36b1b4b3da58b44310e5808c7378
Author: Yizhi Liu <li...@apache.org>
AuthorDate: Fri Jun 15 16:50:01 2018 -0700

    [MXNET-539] Allow Scala users to specify data/label names for NDArrayIter (#11256)
    
    * improve NDArrayIter to have a Builder and the ability to specify names
---
 .../scala/org/apache/mxnet/io/NDArrayIter.scala    | 134 ++++++++++++++++-----
 .../src/test/scala/org/apache/mxnet/IOSuite.scala  |   8 +-
 2 files changed, 113 insertions(+), 29 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 5108938..70c6487 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -28,7 +28,8 @@ import scala.collection.immutable.ListMap
 /**
  * NDArrayIter object in mxnet. Taking NDArray to get dataiter.
  *
- * @param data NDArrayIter supports single or multiple data and label.
+ * @param data Specify the data as well as the name.
+ *             NDArrayIter supports single or multiple data and label.
  * @param label Same as data, but is not fed to the model during testing.
  * @param dataBatchSize Batch Size
  * @param shuffle Whether to shuffle the data
@@ -38,15 +39,35 @@ import scala.collection.immutable.ListMap
  * the size of data does not match batch_size. Roll over is intended
  * for training and can cause problems if used for prediction.
  */
-class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
-                  private val dataBatchSize: Int = 1, shuffle: Boolean = false,
-                  lastBatchHandle: String = "pad",
-                  dataName: String = "data", labelName: String = "label") extends DataIter {
-  private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
+class NDArrayIter(data: IndexedSeq[(String, NDArray)],
+                  label: IndexedSeq[(String, NDArray)],
+                  private val dataBatchSize: Int, shuffle: Boolean,
+                  lastBatchHandle: String) extends DataIter {
+
+  /**
+   * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
+   * @param label Same as data, but is not fed to the model during testing.
+   *              Label names will be label_0, label_1, ..., etc.
+   * @param dataBatchSize Batch Size
+   * @param shuffle Whether to shuffle the data
+   * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
+   *
+   * This iterator will pad, discard or roll over the last batch if
+   * the size of data does not match batch_size. Roll over is intended
+   * for training and can cause problems if used for prediction.
+   */
+  def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
+           dataBatchSize: Int = 1, shuffle: Boolean = false,
+           lastBatchHandle: String = "pad",
+           dataName: String = "data", labelName: String = "label") {
+    this(IO.initData(data, allowEmpty = false, dataName),
+      IO.initData(label, allowEmpty = true, labelName),
+      dataBatchSize, shuffle, lastBatchHandle)
+  }
 
+  private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
 
-  private val (_dataList: IndexedSeq[NDArray],
-  _labelList: IndexedSeq[NDArray]) = {
+  val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = {
     // data should not be null and size > 0
     require(data != null && data.size > 0,
       "data should not be null and data.size should not be zero")
@@ -55,17 +76,17 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
       "label should not be null. Use IndexedSeq.empty if there are no labels")
 
     // shuffle is not supported currently
-    require(shuffle == false, "shuffle is not supported currently")
+    require(!shuffle, "shuffle is not supported currently")
 
     // discard final part if lastBatchHandle equals discard
     if (lastBatchHandle.equals("discard")) {
-      val dataSize = data(0).shape(0)
+      val dataSize = data(0)._2.shape(0)
       require(dataBatchSize <= dataSize,
         "batch_size need to be smaller than data size when not padding.")
       val keepSize = dataSize - dataSize % dataBatchSize
-      val dataList = data.map(ndArray => {ndArray.slice(0, keepSize)})
+      val dataList = data.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
       if (!label.isEmpty) {
-        val labelList = label.map(ndArray => {ndArray.slice(0, keepSize)})
+        val labelList = label.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
         (dataList, labelList)
       } else {
         (dataList, label)
@@ -75,13 +96,9 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
     }
   }
 
-
-  val initData: IndexedSeq[(String, NDArray)] = IO.initData(_dataList, false, dataName)
-  val initLabel: IndexedSeq[(String, NDArray)] = IO.initData(_labelList, true, labelName)
-  val numData = _dataList(0).shape(0)
-  val numSource = initData.size
-  var cursor = -dataBatchSize
-
+  val numData = initData(0)._2.shape(0)
+  val numSource: MXUint = initData.size
+  private var cursor = -dataBatchSize
 
   private val (_provideData: ListMap[String, Shape],
                _provideLabel: ListMap[String, Shape]) = {
@@ -112,8 +129,8 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
    * reset the iterator
    */
   override def reset(): Unit = {
-    if (lastBatchHandle.equals("roll_over") && cursor>numData) {
-      cursor = -dataBatchSize + (cursor%numData)%dataBatchSize
+    if (lastBatchHandle.equals("roll_over") && cursor > numData) {
+      cursor = -dataBatchSize + (cursor%numData) % dataBatchSize
     } else {
       cursor = -dataBatchSize
     }
@@ -154,16 +171,16 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
     newArray
   }
 
-  private def _getData(data: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {
+  private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = {
     require(cursor < numData, "DataIter needs reset.")
     if (data == null) {
       null
     } else {
       if (cursor + dataBatchSize <= numData) {
-        data.map(ndArray => {ndArray.slice(cursor, cursor + dataBatchSize)}).toIndexedSeq
+        data.map { case (_, ndArray) => ndArray.slice(cursor, cursor + dataBatchSize) }
       } else {
         // padding
-        data.map(_padData).toIndexedSeq
+        data.map { case (_, ndArray) => _padData(ndArray) }
       }
     }
   }
@@ -173,7 +190,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
    * @return the data of current batch
    */
   override def getData(): IndexedSeq[NDArray] = {
-    _getData(_dataList)
+    _getData(initData)
   }
 
   /**
@@ -181,7 +198,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
    * @return the label of current batch
    */
   override def getLabel(): IndexedSeq[NDArray] = {
-    _getData(_labelList)
+    _getData(initLabel)
   }
 
   /**
@@ -189,7 +206,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
    * @return
    */
   override def getIndex(): IndexedSeq[Long] = {
-    (cursor.toLong to (cursor + dataBatchSize).toLong).toIndexedSeq
+    cursor.toLong to (cursor + dataBatchSize).toLong
   }
 
   /**
@@ -213,3 +230,66 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
 
   override def batchSize: Int = dataBatchSize
 }
+
+object NDArrayIter {
+
+  /**
+   * Builder class for NDArrayIter.
+   */
+  class Builder() {
+    private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+    private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+    private var dataBatchSize: Int = 1
+    private var lastBatchHandle: String = "pad"
+
+    /**
+     * Add one data input with its name.
+     * @param name Data name.
+     * @param data Data nd-array.
+     * @return The builder object itself.
+     */
+    def addData(name: String, data: NDArray): Builder = {
+      this.data = this.data ++ IndexedSeq((name, data))
+      this
+    }
+
+    /**
+     * Add one label input with its name.
+     * @param name Label name.
+     * @param label Label nd-array.
+     * @return The builder object itself.
+     */
+    def addLabel(name: String, label: NDArray): Builder = {
+      this.label = this.label ++ IndexedSeq((name, label))
+      this
+    }
+
+    /**
+     * Set the batch size of the iterator.
+     * @param batchSize batch size.
+     * @return The builder object itself.
+     */
+    def setBatchSize(batchSize: Int): Builder = {
+      this.dataBatchSize = batchSize
+      this
+    }
+
+    /**
+     * How to handle the last batch.
+     * @param lastBatchHandle Can be "pad", "discard" or "roll_over".
+     * @return The builder object itself.
+     */
+    def setLastBatchHandle(lastBatchHandle: String): Builder = {
+      this.lastBatchHandle = lastBatchHandle
+      this
+    }
+
+    /**
+     * Build the NDArrayIter object.
+     * @return the built object.
+     */
+    def build(): NDArrayIter = {
+      new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle)
+    }
+  }
+}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 0f4b7c0..1b922b3 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -24,7 +24,7 @@ import scala.sys.process._
 
 class IOSuite extends FunSuite with BeforeAndAfterAll {
 
-  private var tu = new TestUtil
+  private val tu = new TestUtil
 
   test("test MNISTIter & MNISTPack") {
     // get data
@@ -258,7 +258,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     assert(batchCount === nBatch0)
 
     // test discard
-    val dataIter1 = new NDArrayIter(data, label, 128, false, "discard")
+    val dataIter1 = new NDArrayIter.Builder()
+      .addData("data0", data(0)).addData("data1", data(1))
+      .addLabel("label", label(0))
+      .setBatchSize(128)
+      .setLastBatchHandle("discard").build()
     val nBatch1 = 7
     batchCount = 0
     while(dataIter1.hasNext) {

-- 
To stop receiving notification emails like this one, please contact
nswamy@apache.org.