You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/06/15 23:50:11 UTC
[incubator-mxnet] branch master updated: [MXNET-539] Allow Scala
users to specify data/label names for NDArrayIter (#11256)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 02e8a71 [MXNET-539] Allow Scala users to specify data/label names for NDArrayIter (#11256)
02e8a71 is described below
commit 02e8a7199fad36b1b4b3da58b44310e5808c7378
Author: Yizhi Liu <li...@apache.org>
AuthorDate: Fri Jun 15 16:50:01 2018 -0700
[MXNET-539] Allow Scala users to specify data/label names for NDArrayIter (#11256)
* improve NDArrayIter to have a Builder and the ability to specify names
---
.../scala/org/apache/mxnet/io/NDArrayIter.scala | 134 ++++++++++++++++-----
.../src/test/scala/org/apache/mxnet/IOSuite.scala | 8 +-
2 files changed, 113 insertions(+), 29 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 5108938..70c6487 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -28,7 +28,8 @@ import scala.collection.immutable.ListMap
/**
* NDArrayIter object in mxnet. Taking NDArray to get dataiter.
*
- * @param data NDArrayIter supports single or multiple data and label.
+ * @param data Specify the data as well as the name.
+ * NDArrayIter supports single or multiple data and label.
* @param label Same as data, but is not fed to the model during testing.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
@@ -38,15 +39,35 @@ import scala.collection.immutable.ListMap
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
-class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
- private val dataBatchSize: Int = 1, shuffle: Boolean = false,
- lastBatchHandle: String = "pad",
- dataName: String = "data", labelName: String = "label") extends DataIter {
- private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
+class NDArrayIter(data: IndexedSeq[(String, NDArray)],
+ label: IndexedSeq[(String, NDArray)],
+ private val dataBatchSize: Int, shuffle: Boolean,
+ lastBatchHandle: String) extends DataIter {
+
+ /**
+ * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
+ * @param label Same as data, but is not fed to the model during testing.
+ * Label names will be label_0, label_1, ..., etc.
+ * @param dataBatchSize Batch Size
+ * @param shuffle Whether to shuffle the data
+ * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
+ *
+ * This iterator will pad, discard or roll over the last batch if
+ * the size of data does not match batch_size. Roll over is intended
+ * for training and can cause problems if used for prediction.
+ */
+ def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
+ dataBatchSize: Int = 1, shuffle: Boolean = false,
+ lastBatchHandle: String = "pad",
+ dataName: String = "data", labelName: String = "label") {
+ this(IO.initData(data, allowEmpty = false, dataName),
+ IO.initData(label, allowEmpty = true, labelName),
+ dataBatchSize, shuffle, lastBatchHandle)
+ }
+ private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
- private val (_dataList: IndexedSeq[NDArray],
- _labelList: IndexedSeq[NDArray]) = {
+ val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = {
// data should not be null and size > 0
require(data != null && data.size > 0,
"data should not be null and data.size should not be zero")
@@ -55,17 +76,17 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
"label should not be null. Use IndexedSeq.empty if there are no labels")
// shuffle is not supported currently
- require(shuffle == false, "shuffle is not supported currently")
+ require(!shuffle, "shuffle is not supported currently")
// discard final part if lastBatchHandle equals discard
if (lastBatchHandle.equals("discard")) {
- val dataSize = data(0).shape(0)
+ val dataSize = data(0)._2.shape(0)
require(dataBatchSize <= dataSize,
"batch_size need to be smaller than data size when not padding.")
val keepSize = dataSize - dataSize % dataBatchSize
- val dataList = data.map(ndArray => {ndArray.slice(0, keepSize)})
+ val dataList = data.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
if (!label.isEmpty) {
- val labelList = label.map(ndArray => {ndArray.slice(0, keepSize)})
+ val labelList = label.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }
(dataList, labelList)
} else {
(dataList, label)
@@ -75,13 +96,9 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
}
}
-
- val initData: IndexedSeq[(String, NDArray)] = IO.initData(_dataList, false, dataName)
- val initLabel: IndexedSeq[(String, NDArray)] = IO.initData(_labelList, true, labelName)
- val numData = _dataList(0).shape(0)
- val numSource = initData.size
- var cursor = -dataBatchSize
-
+ val numData = initData(0)._2.shape(0)
+ val numSource: MXUint = initData.size
+ private var cursor = -dataBatchSize
private val (_provideData: ListMap[String, Shape],
_provideLabel: ListMap[String, Shape]) = {
@@ -112,8 +129,8 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
* reset the iterator
*/
override def reset(): Unit = {
- if (lastBatchHandle.equals("roll_over") && cursor>numData) {
- cursor = -dataBatchSize + (cursor%numData)%dataBatchSize
+ if (lastBatchHandle.equals("roll_over") && cursor > numData) {
+ cursor = -dataBatchSize + (cursor%numData) % dataBatchSize
} else {
cursor = -dataBatchSize
}
@@ -154,16 +171,16 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
newArray
}
- private def _getData(data: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {
+ private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = {
require(cursor < numData, "DataIter needs reset.")
if (data == null) {
null
} else {
if (cursor + dataBatchSize <= numData) {
- data.map(ndArray => {ndArray.slice(cursor, cursor + dataBatchSize)}).toIndexedSeq
+ data.map { case (_, ndArray) => ndArray.slice(cursor, cursor + dataBatchSize) }
} else {
// padding
- data.map(_padData).toIndexedSeq
+ data.map { case (_, ndArray) => _padData(ndArray) }
}
}
}
@@ -173,7 +190,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
* @return the data of current batch
*/
override def getData(): IndexedSeq[NDArray] = {
- _getData(_dataList)
+ _getData(initData)
}
/**
@@ -181,7 +198,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
* @return the label of current batch
*/
override def getLabel(): IndexedSeq[NDArray] = {
- _getData(_labelList)
+ _getData(initLabel)
}
/**
@@ -189,7 +206,7 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
* @return
*/
override def getIndex(): IndexedSeq[Long] = {
- (cursor.toLong to (cursor + dataBatchSize).toLong).toIndexedSeq
+ cursor.toLong to (cursor + dataBatchSize).toLong
}
/**
@@ -213,3 +230,66 @@ class NDArrayIter (data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = Index
override def batchSize: Int = dataBatchSize
}
+
+object NDArrayIter {
+
+ /**
+ * Builder class for NDArrayIter.
+ */
+ class Builder() {
+ private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+ private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+ private var dataBatchSize: Int = 1
+ private var lastBatchHandle: String = "pad"
+
+ /**
+ * Add one data input with its name.
+ * @param name Data name.
+ * @param data Data nd-array.
+ * @return The builder object itself.
+ */
+ def addData(name: String, data: NDArray): Builder = {
+ this.data = this.data ++ IndexedSeq((name, data))
+ this
+ }
+
+ /**
+ * Add one label input with its name.
+ * @param name Label name.
+ * @param label Label nd-array.
+ * @return The builder object itself.
+ */
+ def addLabel(name: String, label: NDArray): Builder = {
+ this.label = this.label ++ IndexedSeq((name, label))
+ this
+ }
+
+ /**
+ * Set the batch size of the iterator.
+ * @param batchSize batch size.
+ * @return The builder object itself.
+ */
+ def setBatchSize(batchSize: Int): Builder = {
+ this.dataBatchSize = batchSize
+ this
+ }
+
+ /**
+ * How to handle the last batch.
+ * @param lastBatchHandle Can be "pad", "discard" or "roll_over".
+ * @return The builder object itself.
+ */
+ def setLastBatchHandle(lastBatchHandle: String): Builder = {
+ this.lastBatchHandle = lastBatchHandle
+ this
+ }
+
+ /**
+ * Build the NDArrayIter object.
+ * @return the built object.
+ */
+ def build(): NDArrayIter = {
+ new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle)
+ }
+ }
+}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 0f4b7c0..1b922b3 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -24,7 +24,7 @@ import scala.sys.process._
class IOSuite extends FunSuite with BeforeAndAfterAll {
- private var tu = new TestUtil
+ private val tu = new TestUtil
test("test MNISTIter & MNISTPack") {
// get data
@@ -258,7 +258,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(batchCount === nBatch0)
// test discard
- val dataIter1 = new NDArrayIter(data, label, 128, false, "discard")
+ val dataIter1 = new NDArrayIter.Builder()
+ .addData("data0", data(0)).addData("data1", data(1))
+ .addLabel("label", label(0))
+ .setBatchSize(128)
+ .setLastBatchHandle("discard").build()
val nBatch1 = 7
batchCount = 0
while(dataIter1.hasNext) {
--
To stop receiving notification emails like this one, please contact
nswamy@apache.org.