You are viewing a plain-text version of this content; the link to the canonical version ("here") was not preserved in this copy.
Posted to commits@mxnet.apache.org by li...@apache.org on 2018/05/23 07:48:17 UTC

[incubator-mxnet] branch v1.2.0-java updated (c887376 -> 9a3cccf)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a change to branch v1.2.0-java
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git.


 discard c887376  add Builder and varargs which are java-friendly
     new 9a3cccf  add Builder and varargs to be java-friendly

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (c887376)
            \
             N -- N -- N   refs/heads/v1.2.0-java (9a3cccf)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revision listed above as "new" is entirely new to this
repository and is described in a separate email.  Revisions
listed as "add" (none in this update) were already present in the
repository and have only been added to this reference.


Summary of changes:

-- 
To stop receiving notification emails like this one, please contact
liuyizhi@apache.org.

[incubator-mxnet] 01/01: add Builder and varargs to be java-friendly

Posted by li...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch v1.2.0-java
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 9a3cccf92a634395364d775af95a42003548f971
Author: Yizhi Liu <yi...@amazon.com>
AuthorDate: Mon May 21 13:48:31 2018 -0700

    add Builder and varargs to be java-friendly
---
 .../core/src/main/scala/org/apache/mxnet/IO.scala  | 56 +++++++++++++++++++++-
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |  3 +-
 .../src/main/scala/org/apache/mxnet/Shape.scala    |  4 ++
 .../src/main/scala/org/apache/mxnet/Symbol.scala   |  1 -
 .../scala/org/apache/mxnet/module/BaseModule.scala | 30 ++++++++++++
 .../scala/org/apache/mxnet/module/Module.scala     | 43 ++++++++++++++++-
 .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 52 ++------------------
 7 files changed, 138 insertions(+), 51 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index 7a9c1a7..123e2f8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -19,9 +19,10 @@ package org.apache.mxnet
 
 import org.apache.mxnet.Base._
 import org.apache.mxnet.DType.DType
-import org.apache.mxnet.io.{MXDataPack, MXDataIter}
+import org.apache.mxnet.io.{MXDataIter, MXDataPack}
 import org.slf4j.LoggerFactory
 
+import scala.annotation.varargs
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ListBuffer
 
@@ -140,6 +141,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
                 // (must match the order of input data/label)
                 private val providedData: ListMap[String, Shape] = null,
                 private val providedLabel: ListMap[String, Shape] = null) {
+
   /**
    * Dispose its data and labels
    * The object shall never be used after it is disposed.
@@ -160,6 +162,58 @@ class DataBatch(val data: IndexedSeq[NDArray],
   def provideLabel: ListMap[String, Shape] = providedLabel
 }
 
+object DataBatch {
+  class Builder() {
+    private var data: IndexedSeq[NDArray] = null
+    private var label: IndexedSeq[NDArray] = null
+    private var index: IndexedSeq[Long] = null
+    private var pad: Int = 0
+    private var bucketKey: AnyRef = null
+    private var providedData: ListMap[String, Shape] = null
+    private var providedLabel: ListMap[String, Shape] = null
+    // Unset fields reach DataBatch as null (pad as 0), like its providedData/Label defaults.
+    @varargs def setData(data: NDArray*): Builder = {
+      this.data = data.toIndexedSeq
+      this
+    }
+
+    @varargs def setLabel(label: NDArray*): Builder = {
+      this.label = label.toIndexedSeq
+      this
+    }
+
+    @varargs def setIndex(index: Long*): Builder = {
+      this.index = index.toIndexedSeq
+      this
+    }
+
+    def setPad(pad: Int): Builder = {
+      this.pad = pad
+      this
+    }
+
+    def setBucketKey(bucketKey: AnyRef): Builder = {
+      this.bucketKey = bucketKey
+      this
+    }
+    /** Records the shape for data input `name`; the map is created on first use. */
+    def provideData(name: String, shape: Shape): Builder = {
+      providedData = Option(providedData).fold(ListMap(name -> shape))(_.updated(name, shape))
+      this
+    }
+    /** Records the shape for label input `name`; the map is created on first use. */
+    def provideLabel(name: String, shape: Shape): Builder = {
+      providedLabel = Option(providedLabel).fold(ListMap(name -> shape))(_.updated(name, shape))
+      this
+    }
+    /** Creates the DataBatch without validation: unset data/label/index are passed as null. */
+    def build(): DataBatch = {
+      new DataBatch(data, label, index, pad,
+        bucketKey, providedData, providedLabel)
+    }
+  }
+}
+
 /**
  * DataIter object in mxnet.
  */
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 416f2d7..e8c687e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -48,6 +48,7 @@ object NDArray {
     }
   }
 
+  // NOTE(review): visibility widened from private[mxnet] to public here; confirm it is intended API.
   /**
    * Used by NDArrayMacro.
    * Invoke this function by passing in parameters.
@@ -57,7 +58,7 @@ object NDArray {
    * @param kwargs Key-value arguments of input scalars
    * @return The result NDArrays of result of computation.
    */
-  private[mxnet] def genericNDArrayFunctionInvoke(
+  def genericNDArrayFunctionInvoke(
     funcName: String, args: Seq[Any], kwargs: Map[String, Any] = null): NDArrayFuncReturn = {
     val function = functions(funcName)
     val ndArgs = ArrayBuffer.empty[NDArray]
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
index e632ade..6891762 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
@@ -17,6 +17,8 @@
 
 package org.apache.mxnet
 
+import scala.annotation.varargs
+
 /**
  * Shape of [[NDArray]] or other data
  */
@@ -28,6 +30,7 @@ class Shape(dims: Traversable[Int]) extends Serializable {
   }
 
   def apply(dim: Int): Int = shape(dim)
+  def get(dim: Int): Int = apply(dim) // same as apply(dim); a named accessor for Java callers
   def size: Int = shape.size
   def length: Int = shape.length
   def drop(dim: Int): Shape = new Shape(shape.drop(dim))
@@ -56,4 +59,5 @@ class Shape(dims: Traversable[Int]) extends Serializable {
 object Shape {
   def apply(dims: Int *): Shape = new Shape(dims: _*)
   def apply(dims: Traversable[Int]): Shape = new Shape(dims)
+  @varargs def create(dims: Int*): Shape = new Shape(dims) // @varargs adds a Java varargs overload
 }
NOTE(review): the Symbol.scala hunk below drops the duplicate-output-name check, so when several outputs share `name` the last match now wins silently; confirm this is intended (it is unrelated to the Builder/varargs change).
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 13f85a7..b6947b4 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -101,7 +101,6 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
     var index: Int = -1
     for ((output, i) <- listOutputs().view.zipWithIndex) {
       if (output == name) {
-        require(index == -1, s"There are multiple outputs with name $name")
         index = i
       }
     }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 108cff4..f7ae883 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -23,6 +23,8 @@ import org.apache.mxnet.optimizer.SGD
 import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 import org.slf4j.Logger
+
+import scala.annotation.varargs
 import scala.collection.mutable.ArrayBuffer
 
 object BaseModule {
@@ -468,6 +470,10 @@ abstract class BaseModule {
    */
   def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit
 
+  def forward(dataBatch: DataBatch, isTrain: Boolean): Unit = {
+    forward(dataBatch, Option(isTrain)) // delegate to the Option[Boolean] overload above
+  }
+
   /**
    * Backward computation.
    * @param outGrads Gradient on the outputs to be propagated back.
@@ -549,6 +555,30 @@ abstract class BaseModule {
            forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None,
            gradReq: String = "write"): Unit
 
+  // State staged by the bindPartial(...) overloads and read by the varargs bind(...) below.
+  protected var labelShapesPartial: IndexedSeq[DataDesc] = _
+  protected var sharedModulePartial: BaseModule = _
+  protected var gradReqPartial: String = "write" // NOTE(review): bind(...) never resets these
+  @varargs def bindPartial(labelShape: DataDesc*): BaseModule = {
+    labelShapesPartial = labelShape.toIndexedSeq
+    this
+  }
+  def bindPartial(sharedModule: BaseModule): BaseModule = {
+    sharedModulePartial = sharedModule
+    this
+  }
+  def bindPartial(gradReq: String): BaseModule = {
+    gradReqPartial = gradReq
+    this
+  }
+  /** Varargs bind using bindPartial state; unset labels/sharedModule are passed as None. */
+  @varargs def bind(forTraining: Boolean, inputsNeedGrad: Boolean,
+                    forceRebind: Boolean, dataShape: DataDesc*): Unit = {
+    bind(dataShape.toVector, Option(labelShapesPartial),
+      forTraining, inputsNeedGrad, forceRebind,
+      Option(sharedModulePartial), gradReqPartial)
+  }
+
   // Install and initialize optimizers.
   def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
                     resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index ac3d645..a46b605 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -17,13 +17,16 @@
 
 package org.apache.mxnet.module
 
-import java.io.{FileInputStream, BufferedInputStream, BufferedOutputStream, FileOutputStream}
+import java.io.{BufferedInputStream, BufferedOutputStream, FileInputStream, FileOutputStream}
+
 import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 import org.apache.mxnet.module.DataParallelExecutorGroup.Builder
 import org.apache.mxnet.optimizer.SGD
 import org.slf4j.LoggerFactory
 
+import scala.annotation.varargs
+
 /**
  * Module is a basic module that wrap a `Symbol`. It is functionally the same
  * as the `FeedForward` model, except under the module API.
@@ -642,4 +645,42 @@ object Module {
     }
     mod
   }
+  /** Java-friendly builder for [[Module]]; optional settings left unset become None. */
+  class Builder (private val modelDef: Symbol) {
+    private var dataNames: IndexedSeq[String] = IndexedSeq("data")
+    private var labelNames: IndexedSeq[String] = IndexedSeq("softmax_label")
+    private var contexts: Array[Context] = Array(Context.cpu())
+    private var workLoadList: IndexedSeq[Float] = _
+    private var fixedParamNames: Set[String] = _
+    /** Sets the contexts (default: Array(Context.cpu())). */
+    @varargs def setContext(ctx: Context*): Builder = {
+      contexts = ctx.toArray
+      this
+    }
+    /** Sets the data input names (default: "data"). */
+    @varargs def setDataNames(name: String*): Builder = {
+      dataNames = name.toVector
+      this
+    }
+    /** Sets the label names (default: "softmax_label"). */
+    @varargs def setLabelNames(name: String*): Builder = {
+      labelNames = name.toVector
+      this
+    }
+    /** Sets the workload list; if never called, Module receives None. */
+    @varargs def setWorkLoadList(workload: Float*): Builder = {
+      workLoadList = workload.toVector
+      this
+    }
+    /** Sets fixedParamNames (stored as a Set); if never called, Module receives None. */
+    @varargs def setFixedParamNames(name: String*): Builder = {
+      fixedParamNames = name.toSet
+      this
+    }
+    /** Creates the Module from the configured values. */
+    def build(): Module = {
+      new Module(modelDef, dataNames, labelNames, contexts,
+                 Option(workLoadList), Option(fixedParamNames))
+    }
+  }
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index c26d14c..c4d16bc 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -52,18 +52,6 @@ private[mxnet] object NDArrayMacro {
       else ndarrayFunctions.filter(!_._1.startsWith("_contrib_"))
     }
 
-    val AST_NDARRAY_TYPE = Select(Select(Select(
-      Ident(TermName("org")), TermName("apache")), TermName("mxnet")), TypeName("NDArray"))
-    val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")),
-      List(Ident(TypeName("String")), Ident(TypeName("Any"))))
-    val AST_TYPE_ANY_VARARG = AppliedTypeTree(
-      Select(
-        Select(Ident(termNames.ROOTPKG), TermName("scala")),
-        TypeName("<repeated>")
-      ),
-      List(Ident(TypeName("Any")))
-    )
-
     val functionDefs = newNDArrayFunctions flatMap { case (funcName, funcProp) =>
       val functionScope = {
         if (isContrib) Modifiers()
@@ -75,45 +63,15 @@ private[mxnet] object NDArrayMacro {
         if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length())
         else funcName
       }
-
+      val termName = TermName(newName) // contrib functions drop the "_contrib_" prefix, as before
       // It will generate definition something like,
       Seq(
+        // scalastyle:off
         // def transpose(kwargs: Map[String, Any] = null)(args: Any*)
-        DefDef(functionScope, TermName(newName), List(),
-          List(
-            List(
-              ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("kwargs"),
-                AST_TYPE_MAP_STRING_ANY, Literal(Constant(null)))
-            ),
-            List(
-              ValDef(Modifiers(), TermName("args"), AST_TYPE_ANY_VARARG, EmptyTree)
-            )
-          ), TypeTree(),
-          Apply(
-            Ident(TermName("genericNDArrayFunctionInvoke")),
-            List(
-              Literal(Constant(funcName)),
-              Ident(TermName("args")),
-              Ident(TermName("kwargs"))
-            )
-          )
-        ),
+        q"$functionScope def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", // functionScope: visibility modifiers computed above
         // def transpose(args: Any*)
-        DefDef(functionScope, TermName(newName), List(),
-          List(
-            List(
-              ValDef(Modifiers(), TermName("args"), AST_TYPE_ANY_VARARG, EmptyTree)
-            )
-          ), TypeTree(),
-          Apply(
-            Ident(TermName("genericNDArrayFunctionInvoke")),
-            List(
-              Literal(Constant(funcName)),
-              Ident(TermName("args")),
-              Literal(Constant(null))
-            )
-          )
-        )
+        { val mods = functionScope.mapAnnotations(_ :+ q"new scala.annotation.varargs()"); q"$mods def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" }
+        // scalastyle:on
       )
     }
 

-- 
To stop receiving notification emails like this one, please contact
liuyizhi@apache.org.