You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/15 23:50:02 UTC

[GitHub] nswamy closed pull request #11256: [MXNET-539] Allow Scala users to specify data/label names for NDArrayIter

nswamy closed pull request #11256: [MXNET-539] Allow Scala users to specify data/label names for NDArrayIter
URL: https://github.com/apache/incubator-mxnet/pull/11256
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 51089382097..70c64877887 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -28,7 +28,8 @@ import scala.collection.immutable.ListMap
 /**
  * NDArrayIter object in mxnet. Taking NDArray to get dataiter.
  *
- * @param data NDArrayIter supports single or multiple data and label.
+ * @param data The data arrays, each paired with the name it is exposed under.
+ *             NDArrayIter supports single or multiple data and label.
  * @param label Same as data, but is not fed to the model during testing.
  * @param dataBatchSize Batch Size
  * @param shuffle Whether to shuffle the data
@@ -38,15 +39,35 @@ import scala.collection.immutable.ListMap
  * the size of data does not match batch_size. Roll over is intended
  * for training and can cause problems if used for prediction.
  */
-class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
-                  private val dataBatchSize: Int = 1, shuffle: Boolean = false,
-                  lastBatchHandle: String = "pad",
-                  dataName: String = "data", labelName: String = "label") extends DataIter {
-  private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
+class NDArrayIter(data: IndexedSeq[(String, NDArray)],
+                  label: IndexedSeq[(String, NDArray)],
+                  private val dataBatchSize: Int, shuffle: Boolean,
+                  lastBatchHandle: String) extends DataIter {
+
+  /**
+   * @param data Specify the data. Names are derived from dataName, e.g. data_0, data_1, ...
+   * @param label Same as data, but is not fed to the model during testing.
+   *              Label names are derived from labelName, e.g. label_0, label_1, ...
+   * @param dataBatchSize Batch Size
+   * @param shuffle Whether to shuffle the data
+   * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
+   *
+   * This iterator will pad, discard or roll over the last batch if
+   * the size of data does not match batch_size. Roll over is intended
+   * for training and can cause problems if used for prediction.
+   */
+  def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
+           dataBatchSize: Int = 1, shuffle: Boolean = false,
+           lastBatchHandle: String = "pad",
+           dataName: String = "data", labelName: String = "label") {
+    this(IO.initData(data, allowEmpty = false, dataName),
+      IO.initData(label, allowEmpty = true, labelName),
+      dataBatchSize, shuffle, lastBatchHandle)
+  }
 
+  private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
 
-  private val (_dataList: IndexedSeq[NDArray],
-  _labelList: IndexedSeq[NDArray]) = {
+  val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = {
     // data should not be null and size > 0
     require(data != null && data.size > 0,
       "data should not be null and data.size should not be zero")
@@ -55,17 +76,17 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
       "label should not be null. Use IndexedSeq.empty if there are no labels")
 
     // shuffle is not supported currently
-    require(shuffle == false, "shuffle is not supported currently")
+    require(!shuffle, "shuffle is not supported currently")
 
     // discard final part if lastBatchHandle equals discard
     if (lastBatchHandle.equals("discard")) {
-      val dataSize = data(0).shape(0)
+      val dataSize = data(0)._2.shape(0)
       require(dataBatchSize <= dataSize,
         "batch_size need to be smaller than data size when not padding.")
       val keepSize = dataSize - dataSize % dataBatchSize
-      val dataList = data.map(ndArray => {ndArray.slice(0, keepSize)})
+      val dataList = data.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
       if (!label.isEmpty) {
-        val labelList = label.map(ndArray => {ndArray.slice(0, keepSize)})
+        val labelList = label.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
         (dataList, labelList)
       } else {
         (dataList, label)
@@ -75,13 +96,9 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
     }
   }
 
-
-  val initData: IndexedSeq[(String, NDArray)] = IO.initData(_dataList, false, dataName)
-  val initLabel: IndexedSeq[(String, NDArray)] = IO.initData(_labelList, true, labelName)
-  val numData = _dataList(0).shape(0)
-  val numSource = initData.size
-  var cursor = -dataBatchSize
-
+  val numData = initData(0)._2.shape(0)
+  val numSource: MXUint = initData.size
+  private var cursor = -dataBatchSize
 
   private val (_provideData: ListMap[String, Shape],
                _provideLabel: ListMap[String, Shape]) = {
@@ -112,8 +129,8 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
    * reset the iterator
    */
   override def reset(): Unit = {
-    if (lastBatchHandle.equals("roll_over") && cursor>numData) {
-      cursor = -dataBatchSize + (cursor%numData)%dataBatchSize
+    if (lastBatchHandle.equals("roll_over") && cursor > numData) {
+      cursor = -dataBatchSize + (cursor%numData) % dataBatchSize
     } else {
       cursor = -dataBatchSize
     }
@@ -154,16 +171,16 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
     newArray
   }
 
-  private def _getData(data: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {
+  private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = {
     require(cursor < numData, "DataIter needs reset.")
     if (data == null) {
       null
     } else {
       if (cursor + dataBatchSize <= numData) {
-        data.map(ndArray => {ndArray.slice(cursor, cursor + dataBatchSize)}).toIndexedSeq
+        data.map { case (_, ndArray) => ndArray.slice(cursor, cursor + dataBatchSize) }
       } else {
         // padding
-        data.map(_padData).toIndexedSeq
+        data.map { case (_, ndArray) => _padData(ndArray) }
       }
     }
   }
@@ -173,7 +190,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
    * @return the data of current batch
    */
   override def getData(): IndexedSeq[NDArray] = {
-    _getData(_dataList)
+    _getData(initData)
   }
 
   /**
@@ -181,7 +198,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
    * @return the label of current batch
    */
   override def getLabel(): IndexedSeq[NDArray] = {
-    _getData(_labelList)
+    _getData(initLabel)
   }
 
   /**
@@ -189,7 +206,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
    * @return
    */
   override def getIndex(): IndexedSeq[Long] = {
-    (cursor.toLong to (cursor + dataBatchSize).toLong).toIndexedSeq
+    cursor.toLong to (cursor + dataBatchSize).toLong
   }
 
   /**
@@ -213,3 +230,66 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
 
   override def batchSize: Int = dataBatchSize
 }
+
+object NDArrayIter {
+
+  /**
+   * Builder for NDArrayIter; defaults are batch size 1 and lastBatchHandle "pad".
+   */
+  class Builder() {
+    private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+    private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+    private var dataBatchSize: Int = 1
+    private var lastBatchHandle: String = "pad"
+
+    /**
+     * Append one named data input; inputs keep the order in which they were added.
+     * @param name Name to pair with this data array.
+     * @param data Data nd-array.
+     * @return This builder, so that calls can be chained.
+     */
+    def addData(name: String, data: NDArray): Builder = {
+      this.data = this.data ++ IndexedSeq((name, data))
+      this
+    }
+
+    /**
+     * Append one named label input; labels keep the order in which they were added.
+     * @param name Name to pair with this label array.
+     * @param label Label nd-array.
+     * @return This builder, so that calls can be chained.
+     */
+    def addLabel(name: String, label: NDArray): Builder = {
+      this.label = this.label ++ IndexedSeq((name, label))
+      this
+    }
+
+    /**
+     * Set the batch size of the iterator.
+     * @param batchSize Batch size; stays at 1 if this method is never called.
+     * @return This builder, so that calls can be chained.
+     */
+    def setBatchSize(batchSize: Int): Builder = {
+      this.dataBatchSize = batchSize
+      this
+    }
+
+    /**
+     * Set how the last batch is handled.
+     * @param lastBatchHandle Can be "pad", "discard" or "roll_over"; stays "pad" if never set.
+     * @return This builder, so that calls can be chained.
+     */
+    def setLastBatchHandle(lastBatchHandle: String): Builder = {
+      this.lastBatchHandle = lastBatchHandle
+      this
+    }
+
+    /**
+     * Build the NDArrayIter from the collected settings; shuffle is always passed as false.
+     * @return The newly constructed NDArrayIter.
+     */
+    def build(): NDArrayIter = {
+      new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle)
+    }
+  }
+}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 0f4b7c0e7a3..1b922b3c05b 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -24,7 +24,7 @@ import scala.sys.process._
 
 class IOSuite extends FunSuite with BeforeAndAfterAll {
 
-  private var tu = new TestUtil
+  private val tu = new TestUtil
 
   test("test MNISTIter & MNISTPack") {
     // get data
@@ -258,7 +258,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     assert(batchCount === nBatch0)
 
     // test discard
-    val dataIter1 = new NDArrayIter(data, label, 128, false, "discard")
+    val dataIter1 = new NDArrayIter.Builder()
+      .addData("data0", data(0)).addData("data1", data(1))
+      .addLabel("label", label(0))
+      .setBatchSize(128)
+      .setLastBatchHandle("discard").build()
     val nBatch1 = 7
     batchCount = 0
     while(dataIter1.hasNext) {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services