You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/27 21:07:32 UTC

[GitHub] yzhliu closed pull request #11045: [MXNET-471] Add Builder class for Scala Module and DataBatch to simplify construction

yzhliu closed pull request #11045: [MXNET-471] Add Builder class for Scala Module and DataBatch to simplify construction
URL: https://github.com/apache/incubator-mxnet/pull/11045
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub does not otherwise display it once the pull request is merged):

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index 7a9c1a76e6f..d9c767cb1fa 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -19,9 +19,10 @@ package org.apache.mxnet
 
 import org.apache.mxnet.Base._
 import org.apache.mxnet.DType.DType
-import org.apache.mxnet.io.{MXDataPack, MXDataIter}
+import org.apache.mxnet.io.{MXDataIter, MXDataPack}
 import org.slf4j.LoggerFactory
 
+import scala.annotation.varargs
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ListBuffer
 
@@ -160,6 +161,108 @@ class DataBatch(val data: IndexedSeq[NDArray],
   def provideLabel: ListMap[String, Shape] = providedLabel
 }
 
+object DataBatch {
+  /**
+   * Builder class for DataBatch.
+   * Only the data is mandatory; every other field stays null (or 0 for `pad`)
+   * unless it is set explicitly.
+   */
+  class Builder() {
+    private var data: IndexedSeq[NDArray] = null
+    private var label: IndexedSeq[NDArray] = null
+    private var index: IndexedSeq[Long] = null
+    private var pad: Int = 0
+    private var bucketKey: AnyRef = null
+    private var dataShapes: ListMap[String, Shape] = null
+    private var labelShapes: ListMap[String, Shape] = null
+
+    /**
+     * Set the input data.
+     * @param data a list of data.
+     * @return this.
+     */
+    @varargs def setData(data: NDArray*): Builder = {
+      // A null argument list leaves the data unset, so build() reports it as missing
+      // instead of this call failing with a NullPointerException.
+      this.data = Option(data).map(_.toIndexedSeq).orNull
+      this
+    }
+
+    /**
+     * Set the labels in the same order of data.
+     * @param label a list of labels; a null list leaves the labels unset (null).
+     * @return this.
+     */
+    @varargs def setLabel(label: NDArray*): Builder = {
+      this.label = Option(label).map(_.toIndexedSeq).orNull
+      this
+    }
+
+    /**
+     * Set the example indices in this batch.
+     * @param index indices in the same order of data; a null list leaves them unset (null).
+     * @return this.
+     */
+    @varargs def setIndex(index: Long*): Builder = {
+      this.index = Option(index).map(_.toIndexedSeq).orNull
+      this
+    }
+
+    /**
+     * Set the pad.
+     * @param pad The number of examples padded at the end of a batch. It is used when the
+     *            total number of examples read is not divisible by the `batch_size`.
+     *            These extra padded examples are ignored in prediction.
+     * @return this
+     */
+    def setPad(pad: Int): Builder = {
+      this.pad = pad
+      this
+    }
+
+    /**
+     * Set the bucket key, used for bucketing module.
+     * @param bucketKey the bucket key related to this batch.
+     * @return this.
+     */
+    def setBucketKey(bucketKey: AnyRef): Builder = {
+      this.bucketKey = bucketKey
+      this
+    }
+
+    /**
+     * Provide the shape of a data.
+     * @param name data name.
+     * @param shape data shape.
+     * @return this.
+     */
+    def provideDataShape(name: String, shape: Shape): Builder = {
+      dataShapes = withShape(dataShapes, name, shape)
+      this
+    }
+
+    /**
+     * Provide the shape of a label.
+     * @param name label name.
+     * @param shape label shape.
+     * @return this.
+     */
+    def provideLabelShape(name: String, shape: Shape): Builder = {
+      labelShapes = withShape(labelShapes, name, shape)
+      this
+    }
+
+    /**
+     * Build the DataBatch from the values set on this builder.
+     * @return a new DataBatch.
+     * @throws IllegalArgumentException if no data has been set.
+     */
+    def build(): DataBatch = {
+      require(data != null, "data is required.")
+      new DataBatch(data, label, index, pad, bucketKey, dataShapes, labelShapes)
+    }
+
+    // Adds or replaces `name -> shape`, starting a new map when `shapes` is still unset.
+    private def withShape(shapes: ListMap[String, Shape], name: String,
+                          shape: Shape): ListMap[String, Shape] = {
+      if (shapes == null) ListMap((name, shape)) else shapes.updated(name, shape)
+    }
+  }
+}
+
 /**
  * DataIter object in mxnet.
  */
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
index e632ade808e..68917621772 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
@@ -17,6 +17,8 @@
 
 package org.apache.mxnet
 
+import scala.annotation.varargs
+
 /**
  * Shape of [[NDArray]] or other data
  */
@@ -28,6 +30,7 @@ class Shape(dims: Traversable[Int]) extends Serializable {
   }
 
   def apply(dim: Int): Int = shape(dim)
+  /** Named alternative to `apply`: returns the extent of dimension `dim`. */
+  def get(dim: Int): Int = this.apply(dim)
   def size: Int = shape.size
   def length: Int = shape.length
   def drop(dim: Int): Shape = new Shape(shape.drop(dim))
@@ -56,4 +59,5 @@ class Shape(dims: Traversable[Int]) extends Serializable {
 object Shape {
   def apply(dims: Int *): Shape = new Shape(dims: _*)
   def apply(dims: Traversable[Int]): Shape = new Shape(dims)
+  /** Builds a Shape from the given dimensions; a `@varargs` counterpart of `apply`. */
+  @varargs def create(dims: Int*): Shape = apply(dims: Traversable[Int])
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 60efd2ba62b..a17fe57dde6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -101,7 +101,6 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
     var index: Int = -1
     for ((output, i) <- listOutputs().view.zipWithIndex) {
       if (output == name) {
-        require(index == -1, s"There are multiple outputs with name $name")
         index = i
       }
     }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 108cff44965..60b80f25285 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -23,6 +23,8 @@ import org.apache.mxnet.optimizer.SGD
 import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 import org.slf4j.Logger
+
+import scala.annotation.varargs
 import scala.collection.mutable.ArrayBuffer
 
 object BaseModule {
@@ -468,6 +470,15 @@ abstract class BaseModule {
    */
   def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit
 
+  /**
+   * Forward computation; a convenience overload taking a plain `Boolean`
+   * instead of an `Option[Boolean]`.
+   * @param dataBatch a batch of data.
+   * @param isTrain whether the computation is for training; always passed on as `Some(isTrain)`.
+   */
+  def forward(dataBatch: DataBatch, isTrain: Boolean): Unit =
+    forward(dataBatch, Some(isTrain))
+
   /**
    * Backward computation.
    * @param outGrads Gradient on the outputs to be propagated back.
@@ -549,6 +560,25 @@ abstract class BaseModule {
            forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None,
            gradReq: String = "write"): Unit
 
+  /**
+   * Bind the symbols to construct executors.
+   * This is necessary before one can perform computation with the module.
+   * Unlike the main `bind`, this overload has no default arguments: every flag
+   * must be given explicitly.
+   * @param forTraining Whether the executors should be bound for training.
+   * @param inputsNeedGrad Whether the gradients to the input data need to be computed.
+   *                       Typically this is not needed,
+   *                       but it might be needed when implementing composition of modules.
+   * @param forceRebind When `false`, this function does nothing if the executors are
+   *                    already bound; when `true`, the executors are forced to rebind.
+   * @param dataShape Typically is `DataIter.provideData`.
+   */
+  @varargs def bind(forTraining: Boolean, inputsNeedGrad: Boolean,
+                    forceRebind: Boolean, dataShape: DataDesc*): Unit = {
+    // Positional call into the main bind. The two `None`s are presumably the label shapes
+    // and the shared module, and the trailing gradReq keeps its default (TODO confirm
+    // against the full main `bind` signature).
+    bind(dataShape.toVector, None, forTraining, inputsNeedGrad, forceRebind, None)
+  }
+
   // Install and initialize optimizers.
   def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
                     resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index ac3d645b333..d55a42653ce 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -17,13 +17,16 @@
 
 package org.apache.mxnet.module
 
-import java.io.{FileInputStream, BufferedInputStream, BufferedOutputStream, FileOutputStream}
+import java.io.{BufferedInputStream, BufferedOutputStream, FileInputStream, FileOutputStream}
+
 import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 import org.apache.mxnet.module.DataParallelExecutorGroup.Builder
 import org.apache.mxnet.optimizer.SGD
 import org.slf4j.LoggerFactory
 
+import scala.annotation.varargs
+
 /**
  * Module is a basic module that wrap a `Symbol`. It is functionally the same
  * as the `FeedForward` model, except under the module API.
@@ -642,4 +645,72 @@ object Module {
     }
     mod
   }
+
+  /**
+   * Builder class for Module.
+   * Unless overridden, it uses data names `data`, label names `softmax_label`,
+   * a single `Context.cpu()` context, and no workload list or fixed parameter names.
+   * @param modelDef model definition in Symbol.
+   */
+  class Builder(private val modelDef: Symbol) {
+    private var dataNames: IndexedSeq[String] = IndexedSeq("data")
+    private var labelNames: IndexedSeq[String] = IndexedSeq("softmax_label")
+    private var contexts: Array[Context] = Array(Context.cpu())
+    // Left null until set; build() turns null into None.
+    private var workLoadList: IndexedSeq[Float] = _
+    private var fixedParamNames: Set[String] = _
+
+    /**
+     * Set the context for execution.
+     * @param ctx a list of contexts.
+     * @return this.
+     */
+    @varargs def setContext(ctx: Context*): Builder = {
+      contexts = ctx.toArray
+      this
+    }
+
+    /**
+     * Set the input data names.
+     * @param name a list of data names. Neither the list nor any name in it may be null.
+     * @return this.
+     * @throws IllegalArgumentException if `name` is null or contains null.
+     */
+    @varargs def setDataNames(name: String*): Builder = {
+      require(name != null && !name.contains(null), "data names cannot be null.")
+      dataNames = name.toVector
+      this
+    }
+
+    /**
+     * Set the label names.
+     * @param name a list of label names.
+     *             Set to null if no label is required.
+     * @return this.
+     */
+    @varargs def setLabelNames(name: String*): Builder = {
+      // `name` may itself be null (e.g. a null array passed to the varargs forwarder),
+      // while the Scala call `setLabelNames(null)` yields a one-element sequence holding
+      // null. Both, and any other null entries, mean "no label".
+      labelNames = Option(name).fold(IndexedSeq.empty[String])(_.filter(_ != null).toVector)
+      this
+    }
+
+    /**
+     * Set the workloads.
+     * @param workloads a list of workloads
+     * @return this.
+     */
+    @varargs def setWorkLoadList(workloads: Float*): Builder = {
+      workLoadList = workloads.toVector
+      this
+    }
+
+    /**
+     * Specify the parameters need to be fixed.
+     * @param name a list of parameter names.
+     * @return this.
+     */
+    @varargs def setFixedParamNames(name: String*): Builder = {
+      fixedParamNames = name.toSet
+      this
+    }
+
+    /**
+     * Build the Module from the settings of this builder.
+     * @return a new Module; an unset workload list or set of fixed parameter names
+     *         is passed to it as None.
+     */
+    def build(): Module = {
+      new Module(modelDef, dataNames, labelNames, contexts,
+        Option(workLoadList), Option(fixedParamNames))
+    }
+  }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index a9cac131dd2..22b9c3bdaf3 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.mxnet
 
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import org.apache.mxnet.CheckUtils._
 import org.apache.mxnet.module._
 import org.apache.mxnet.optimizer._
 import org.apache.mxnet.io._
@@ -52,8 +51,11 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     import SymbolConversions._
     c = a + 2 * b + 3 * c
 
-    val mod = new Module(c, IndexedSeq("b", "c", "a"), null,
-      contexts = Array(Context.cpu(0), Context.cpu(1)))
+    val mod = new Module.Builder(c)
+      .setDataNames("b", "c", "a")
+      .setLabelNames(null)
+      .setContext(Context.cpu(0), Context.cpu(1))
+      .build()
     mod.bind(dataShapes = IndexedSeq(
       DataDesc("b", Shape(5, 5), layout = "NT"),
       DataDesc("c", Shape(5, 5), layout = "NT"),
@@ -342,11 +344,13 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     dShape1 = Shape(20, 3, 120, 120)
     dShape2 = Shape(20, 3, 32, 64)
     lShape = Shape(20)
-    dataBatch = new DataBatch(
-      data = IndexedSeq(
+    dataBatch = new DataBatch.Builder()
+      .setData(
         NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
-        NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()),
-      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+        NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))())
+      .setLabel(NDArray.ones(lShape))
+      .setPad(0)
+      .build()
     mod.forward(dataBatch)
     assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
     mod.backward()
@@ -355,11 +359,13 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     dShape1 = Shape(5, 3, 28, 40)
     dShape2 = Shape(5, 3, 24, 16)
     lShape = Shape(5)
-    dataBatch = new DataBatch(
-      data = IndexedSeq(
+    dataBatch = new DataBatch.Builder()
+      .setData(
         NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
-        NDArray.random_uniform(Map("low" -> 15, "high" -> 25, "shape" -> dShape2.toString()))()),
-      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+        NDArray.random_uniform(Map("low" -> 15, "high" -> 25, "shape" -> dShape2.toString()))())
+      .setLabel(NDArray.ones(lShape))
+      .setPad(0)
+      .build()
     mod.forward(dataBatch)
     assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
     mod.backward()
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index bbe786f5a0a..a6e7b5fda55 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -66,6 +66,9 @@ private[mxnet] object NDArrayMacro {
         if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) {
           Seq(
             // scalastyle:off
+            // (yizhi) We are investigating a way to make these functions type-safe
+            // and waiting to see the new approach is stable enough.
+            // Thus these functions may be deprecated in the future.
             // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
             q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
             // e.g def transpose(args: Any*)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services