You are viewing a plain text version of this content. The canonical link for it is here.
Posted by GitBox on 2018/11/01 21:17:25 UTC

[GitHub] mdespriee commented on a change in pull request #13038: [MXNET-918] Introduce Random module / Refact code generation

mdespriee commented on a change in pull request #13038: [MXNET-918] Introduce Random module / Refact code generation

 File path: scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
 @@ -30,207 +27,111 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot
 private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
-  private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs
+  private[mxnet] def macroTransform(annottees: Any*) = macro TypedNDArrayAPIMacro.typeSafeAPIDefs
-private[mxnet] object NDArrayMacro {
-  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
-  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
-  // scalastyle:off havetype
-  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    impl(c)(annottees: _*)
-  }
-  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    typeSafeAPIImpl(c)(annottees: _*)
-  }
-  // scalastyle:off havetype
-  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
+private[mxnet] object NDArrayMacro extends GeneratorBase {
-  private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
+  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b))
-    val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
-      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
-    }
-     val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
-        val funcName = NDArrayfunction.name
-        val termName = TermName(funcName)
-       Seq(
-            // scalastyle:off
-            // (yizhi) We are investigating a way to make these functions type-safe
-            // and waiting to see the new approach is stable enough.
-            // Thus these functions may be deprecated in the future.
-            // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
-            q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
-            // e.g def transpose(args: Any*)
-            q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
-            // scalastyle:on
-          )
-        }
-    structGeneration(c)(functionDefs, annottees : _*)
+    impl(c)(isContrib, annottees: _*)
-  private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+  private def impl(c: blackbox.Context)
+                  (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
-    val isContrib: Boolean = c.prefix.tree match {
-      case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
-    }
-    // Defines Operators that should not generated
-    val notGenerated = Set("Custom")
-    val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(
-        func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
-      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
-    }.filterNot(ele => notGenerated.contains(ele.name))
-    val functionDefs = newNDArrayFunctions.map { ndarrayfunction =>
-      // Construct argument field
-      var argDef = ListBuffer[String]()
-      // Construct Implementation field
-      var impl = ListBuffer[String]()
-      impl += "val map = scala.collection.mutable.Map[String, Any]()"
-      impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
-      ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
-        // var is a special word used to define variable in Scala,
-        // need to changed to something else in order to make it work
-        val currArgName = ndarrayarg.argName match {
-          case "var" => "vari"
-          case "type" => "typeOf"
-          case default => ndarrayarg.argName
-        }
-        if (ndarrayarg.isOptional) {
-          argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
-        }
-        else {
-          argDef += s"${currArgName} : ${ndarrayarg.argType}"
-        }
-        // NDArray arg implementation
-        val returnType = "org.apache.mxnet.NDArray"
-        // TODO: Currently we do not add place holder for NDArray
-        // Example: an NDArray operator like the following format
-        // NDArray(required), arg2: NDArray(Optional), arg3: NDArray(Optional)
-        // If we place, arg3 = arg3), do we need to add place holder for arg2?
-        // What it should be?
-        val base =
-          if (ndarrayarg.argType.equals(returnType)) {
-            s"args += $currArgName"
-          } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){
-            s"args ++= $currArgName"
-          } else {
-            "map(\"" + ndarrayarg.argName + "\") = " + currArgName
-          }
-        impl.append(
-          if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
-          else base
-        )
-      })
-      // add default out parameter
-      argDef += "out : Option[NDArray] = None"
-      impl += "if (!out.isEmpty) map(\"out\") = out.get"
-      // scalastyle:off
-      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
-      // scalastyle:on
-      // Combine and build the function string
-      val returnType = "org.apache.mxnet.NDArrayFuncReturn"
-      var finalStr = s"def ${ndarrayfunction.name}"
-      finalStr += s" (${argDef.mkString(",")}) : $returnType"
-      finalStr += s" = {${impl.mkString("\n")}}"
-      c.parse(finalStr).asInstanceOf[DefDef]
+    val functions = functionsToGenerate(isSymbol = false, isContrib)
+    val functionDefs = functions.flatMap { NDArrayfunction =>
+      val funcName = NDArrayfunction.name
+      val termName = TermName(funcName)
+      Seq(
+        // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
+        q"""
+             def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {
+               genericNDArrayFunctionInvoke($funcName, args, kwargs)
+             }
+          """.asInstanceOf[DefDef],
+        // e.g def transpose(args: Any*)
+        q"""
+             def $termName(args: Any*) = {
+               genericNDArrayFunctionInvoke($funcName, args, null)
+             }
+          """.asInstanceOf[DefDef]
+      )
-    structGeneration(c)(functionDefs, annottees : _*)
+    structGeneration(c)(functionDefs, annottees: _*)
-  private def structGeneration(c: blackbox.Context)
-                              (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*)
-  : c.Expr[Any] = {
+private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
-    val inputs = annottees.map(_.tree).toList
-    // pattern match on the inputs
-    val modDefs = inputs map {
-      case ClassDef(mods, name, something, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ClassDef(mods, name, something, q)
-      case ModuleDef(mods, name, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ModuleDef(mods, name, q)
-      case ex =>
-        throw new IllegalArgumentException(s"Invalid macro input: $ex")
+    val isContrib: Boolean = c.prefix.tree match {
+      case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
-    // wrap the result up in an Expr, and return it
-    val result = c.Expr(Block(modDefs, Literal(Constant())))
-    result
+    val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+    val functionDefs = functions.map(f => buildTypedFunction(c)(f))
+    structGeneration(c)(functionDefs, annottees: _*)
+  protected def buildTypedFunction(c: blackbox.Context)
+                                  (function: Func): c.universe.DefDef = {
+    import c.universe._
+    val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+    val ndarrayType = "org.apache.mxnet.NDArray"
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= typedFunctionCommonArgDef(function)
+    argDef += "out : Option[NDArray] = None"
-  // List and add all the atomic symbol functions to current module.
-  private def initNDArrayModule(): List[NDArrayFunction] = {
-    val opNames = ListBuffer.empty[String]
-    _LIB.mxListAllOpNames(opNames)
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeNDArrayFunction(opHandle.value, opName)
-    }).toList
-  }
+    // Construct Implementation field
+    var impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += s"val args = scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]"
-  // Create an atomic symbol function by handle and function name.
-  private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
-  : NDArrayFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs)
-    val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) {
-      s"This function support variable length of positional input (${keyVarNumArgs.value})."
-    } else {
-      ""
-    }
-    val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"
-    val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
-    // scalastyle:off println
-    if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
-      && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
-      println("NDArray function definition:\n" + docStr)
-    }
-    // scalastyle:on println
-    val argList = argNames zip argTypes map { case (argName, argType) =>
-      val typeAndOption =
-        CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.NDArray")
-      new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
+    // NDArray arg implementation
+    impl ++= function.listOfArgs.map { arg =>
+      if (arg.argType.equals(s"Array[$ndarrayType]")) {
+        s"args ++= ${arg.safeArgName}"
+      } else {
+        val base =
+          if (arg.argType.equals(ndarrayType)) {
+            // ndarrays go to args
+            s"args += ${arg.safeArgName}"
+          } else {
+            // other types go to kwargs
+            s"""map("${arg.argName}") = ${arg.safeArgName}"""
+          }
+        if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
+        else base
+      }
-    new NDArrayFunction(aliasName, argList.toList)
+    impl +=
+      s"""if (!out.isEmpty) map("out") = out.get
+         |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+         |  "${function.name}", args.toSeq, map.toMap)
+       """.stripMargin
+    // Combine and build the function string
+    val finalStr =
+      s"""def ${function.name}
+         |   (${argDef.mkString(",")}) : $returnType
+         | = {${impl.mkString("\n")}}
 Review comment:
  I suggest we discuss this in the next PR. Introducing the random module required refactoring and splitting this code further to ease the integration of module variants.

This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:

With regards,
Apache Git Services