You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/22 08:05:23 UTC

[GitHub] nswamy commented on a change in pull request #10787: [MXNET-357] New Scala API Design (NDArray)

nswamy commented on a change in pull request #10787: [MXNET-357] New Scala API Design (NDArray)
URL: https://github.com/apache/incubator-mxnet/pull/10787#discussion_r189806243
 
 

 ##########
 File path: scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
 ##########
 @@ -29,67 +29,143 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot
   private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addDefs
 }
 
+private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs
+}
+
 private[mxnet] object NDArrayMacro {
-  case class NDArrayFunction(handle: NDArrayHandle)
+  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
+  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
 
   // scalastyle:off havetype
   def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    impl(c)(false, annottees: _*)
+    impl(c)(annottees: _*)
+  }
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+    typeSafeAPIImpl(c)(annottees: _*)
   }
   // scalastyle:off havetype
 
-  private val ndarrayFunctions: Map[String, NDArrayFunction] = initNDArrayModule()
+  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
 
-  private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = {
+  private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
 
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b))
     }
 
     val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(_._1.startsWith("_contrib_"))
-      else ndarrayFunctions.filter(!_._1.startsWith("_contrib_"))
+      if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
+      else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
     }
 
-    val functionDefs = newNDArrayFunctions flatMap { case (funcName, funcProp) =>
-      val functionScope = {
-        if (isContrib) Modifiers()
-        else {
-          if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else Modifiers()
+     val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
+        val funcName = NDArrayfunction.name
+        val termName = TermName(funcName)
+        if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) {
+          Seq(
+            // scalastyle:off
+            // def transpose(kwargs: Map[String, Any] = null)(args: Any*)
+            q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
+            // def transpose(args: Any*)
+            q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
+            // scalastyle:on
+          )
+        } else {
+          // Default private
+          Seq(
+            // scalastyle:off
+            q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
+            q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
+            // scalastyle:on
+          )
         }
       }
-      val newName = {
-        if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length())
-        else funcName
-      }
-      val termName = TermName(funcName)
-      // It will generate definition something like,
-      Seq(
-        // scalastyle:off
-        // def transpose(kwargs: Map[String, Any] = null)(args: Any*)
-        q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}",
-        // def transpose(args: Any*)
-        q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}"
-        // scalastyle:on
-      )
+
+    structGeneration(c)(functionDefs, annottees : _*)
+  }
+
+  private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+    import c.universe._
+
+    val isContrib: Boolean = c.prefix.tree match {
+      case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
+    }
+    val newNDArrayFunctions = {
+      if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
+      else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
     }
 
+    val functionDefs = newNDArrayFunctions map { ndarrayfunction =>
+
+      // Construct argument field
+      var argDef = ListBuffer[String]()
+      ndarrayfunction.listOfArgs.foreach(ndarrayarg => {
+        val currArgName = ndarrayarg.argName match {
+          case "var" => "vari"
 
 Review comment:
  Please add comments explaining why you are transforming these argument names (e.g. `var` → `vari`).

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services