You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by li...@apache.org on 2018/05/23 18:43:28 UTC

[incubator-mxnet] branch master updated: [MXNET-357] New Scala API Design (NDArray) (#10787)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b0d632f  [MXNET-357] New Scala API Design (NDArray) (#10787)
b0d632f is described below

commit b0d632f7ed9d59508e01e391fbe111ec5d1d2edd
Author: Lanking <la...@live.com>
AuthorDate: Wed May 23 11:43:20 2018 -0700

    [MXNET-357] New Scala API Design (NDArray) (#10787)
    
    * Add new NDArray APIs
    
    * Add NDArray APIs
    
    * change the impl into individual functions and add comments
    
    * Quick fix on redundant code
    
    * Change in Sync
---
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |   2 +
 .../main/scala/org/apache/mxnet/NDArrayAPI.scala   |  24 +++
 .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 195 +++++++++++++++++----
 3 files changed, 189 insertions(+), 32 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 416f2d7..469107a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -37,6 +37,8 @@ object NDArray {
 
   private val functions: Map[String, NDArrayFunction] = initNDArrayModule()
 
+  val api = NDArrayAPI
+
   private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = {
     froms.foreach { from =>
       val weakRef = new WeakReference(from)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
new file mode 100644
index 0000000..d234ac6
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet
+/**
+  * Type-safe NDArray API, accessed as NDArray.api._
+  * Its members are generated at compile time by the AddNDArrayAPIs macro annotation.
+  */
+@AddNDArrayAPIs(false)
+object NDArrayAPI {
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index 036b9ec..bbe786f 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -29,18 +29,26 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot
   private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addDefs
 }
 
+private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs
+}
+
 private[mxnet] object NDArrayMacro {
-  case class NDArrayFunction(handle: NDArrayHandle)
+  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
+  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
 
   // scalastyle:off havetype
   def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    impl(c)(false, annottees: _*)
+    impl(c)(annottees: _*)
+  }
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+    typeSafeAPIImpl(c)(annottees: _*)
   }
   // scalastyle:off havetype
 
-  private val ndarrayFunctions: Map[String, NDArrayFunction] = initNDArrayModule()
+  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
 
-  private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = {
+  private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
 
     val isContrib: Boolean = c.prefix.tree match {
@@ -48,40 +56,99 @@ private[mxnet] object NDArrayMacro {
     }
 
     val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(_._1.startsWith("_contrib_"))
-      else ndarrayFunctions.filter(!_._1.startsWith("_contrib_"))
+      if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
+      else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
     }
 
-    val functionDefs = newNDArrayFunctions flatMap { case (funcName, funcProp) =>
-      val functionScope = {
-        if (isContrib) Modifiers()
-        else {
-          if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else Modifiers()
+     val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
+        val funcName = NDArrayfunction.name
+        val termName = TermName(funcName)
+        if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) {
+          Seq(
+            // scalastyle:off
+            // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
+            q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
+            // e.g def transpose(args: Any*)
+            q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
+            // scalastyle:on
+          )
+        } else {
+          // Default private
+          Seq(
+            // scalastyle:off
+            q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
+            q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
+            // scalastyle:on
+          )
         }
       }
-      val newName = {
-        if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length())
-        else funcName
-      }
-      val termName = TermName(funcName)
-      // It will generate definition something like,
-      Seq(
-        // scalastyle:off
-        // def transpose(kwargs: Map[String, Any] = null)(args: Any*)
-        q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}",
-        // def transpose(args: Any*)
-        q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}"
-        // scalastyle:on
-      )
+
+    structGeneration(c)(functionDefs, annottees : _*)
+  }
+
+  private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+    import c.universe._
+    // Build one typed "<op>New" method per operator and add them to the annottee
+    val isContrib: Boolean = c.prefix.tree match {
+      case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
+    }
+    val newNDArrayFunctions = {
+      if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
+      else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
+    }
+
+    val functionDefs = newNDArrayFunctions map { ndarrayfunction =>
+
+      // Parameter declarations of the generated method, e.g. "axis : Option[Int] = None"
+      var argDef = ListBuffer[String]()
+      // Statements of the generated body: fill `map` with the args, then invoke the op
+      var impl = ListBuffer[String]()
+      impl += "val map = scala.collection.mutable.Map[String, Any]()"
+      ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
+        // "var" and "type" are Scala keywords, so rename them (other keywords are not
+        // escaped); the original argName is still used as the kwargs key below.
+        val currArgName = ndarrayarg.argName match {
+          case "var" => "vari"
+          case "type" => "typeOf"
+          case default => ndarrayarg.argName
+        }
+        if (ndarrayarg.isOptional) {
+          argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
+        }
+        else {
+          argDef += s"${currArgName} : ${ndarrayarg.argType}"
+        }
+        var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName
+        if (ndarrayarg.isOptional) {
+          base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
+        }
+        impl += base
+      })
+      // scalastyle:off
+      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)"
+      // scalastyle:on
+      // Assemble "def <name>New (<params>) : NDArray = {<body>}" and parse it to a DefDef
+      val returnType = "org.apache.mxnet.NDArray"
+      var finalStr = s"def ${ndarrayfunction.name}New"
+      finalStr += s" (${argDef.mkString(",")}) : $returnType"
+      finalStr += s" = {${impl.mkString("\n")}}"
+      c.parse(finalStr).asInstanceOf[DefDef]
     }
 
+    structGeneration(c)(functionDefs, annottees : _*)
+  }
+
+  private def structGeneration(c: blackbox.Context)
+                              (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*)
+  : c.Expr[Any] = {
+    import c.universe._
     val inputs = annottees.map(_.tree).toList
     // pattern match on the inputs
     val modDefs = inputs map {
       case ClassDef(mods, name, something, template) =>
         val q = template match {
           case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ functionDefs)
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
           case ex =>
             throw new IllegalArgumentException(s"Invalid template: $ex")
         }
@@ -89,7 +156,7 @@ private[mxnet] object NDArrayMacro {
       case ModuleDef(mods, name, template) =>
         val q = template match {
           case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ functionDefs)
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
           case ex =>
             throw new IllegalArgumentException(s"Invalid template: $ex")
         }
@@ -102,20 +169,80 @@ private[mxnet] object NDArrayMacro {
     result
   }
 
+
+  // Map a C++ argument type name to a Scala type name; argType only feeds the error message
+  private def typeConversion(in : String, argType : String = "") : String = {
+    in match {
+      case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
+      case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.NDArray"
+      case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
+      => "Array[org.apache.mxnet.NDArray]"
+      case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
+      case "int" | "intorNone" | "int(non-negative)" => "Int"
+      case "long" | "long(non-negative)" => "Long"
+      case "double" | "doubleorNone" => "Double"
+      case "string" => "String"
+      case "boolean" | "booleanorNone" => "Boolean"
+      case "tupleof<float>" | "tupleof<double>" | "ptr" | "" => "Any"
+      case default => throw new IllegalArgumentException(
+        s"Invalid type for args: $default, $argType")
+    }
+  }
+
+
+  /**
+    * Split a raw C++ argument description, e.g. "int, optional, default='0'", into
+    * (Scala type, isOptional). Whitespace is ignored and not every field is always present;
+    * an enumerated set such as {'csr', 'default'} is treated as type "string".
+    * Three or more fields must read "<type>, optional, default=..." and are optional;
+    * with one or two fields, only plain NDArray-typed arguments are flagged optional.
+    * No default value is extracted; unknown types or field layouts are rejected.
+    * @param argType Raw argument type description; must be non-empty
+    * @return (fully qualified Scala type name, isOptional)
+    */
+  private def argumentCleaner(argType : String) : (String, Boolean) = {
+    val spaceRemoved = argType.replaceAll("\\s+", "")
+    var commaRemoved : Array[String] = new Array[String](0)
+    // Enumerated choices, e.g. stype : {'csr', 'default', 'row_sparse'}, become "string"
+    if (spaceRemoved.charAt(0)== '{') {
+      val endIdx = spaceRemoved.indexOf('}')
+      commaRemoved = spaceRemoved.substring(endIdx + 1).split(",")
+      commaRemoved(0) = "string"
+    } else {
+      commaRemoved = spaceRemoved.split(",")
+    }
+    // "<type>, optional, default=<value>": the default value itself is discarded
+    if (commaRemoved.length >= 3) {
+      // require() throws IllegalArgumentException if the layout differs
+      require(commaRemoved(1).equals("optional"))
+      require(commaRemoved(2).startsWith("default="))
+      (typeConversion(commaRemoved(0), argType), true)
+    } else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
+      val tempType = typeConversion(commaRemoved(0), argType)
+      val tempOptional = tempType.equals("org.apache.mxnet.NDArray")
+      (tempType, tempOptional)
+    } else {
+      throw new IllegalArgumentException(
+        s"Unrecognized arg field: $argType, ${commaRemoved.length}")
+    }
+
+  }
+
+
   // List and add all the atomic symbol functions to current module.
-  private def initNDArrayModule(): Map[String, NDArrayFunction] = {
+  private def initNDArrayModule(): List[NDArrayFunction] = {
     val opNames = ListBuffer.empty[String]
     _LIB.mxListAllOpNames(opNames)
     opNames.map(opName => {
       val opHandle = new RefLong
       _LIB.nnGetOpHandle(opName, opHandle)
       makeNDArrayFunction(opHandle.value, opName)
-    }).toMap
+    }).toList
   }
 
   // Create an atomic symbol function by handle and function name.
   private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
-    : (String, NDArrayFunction) = {
+  : NDArrayFunction = {
     val name = new RefString
     val desc = new RefString
     val keyVarNumArgs = new RefString
@@ -136,10 +263,14 @@ private[mxnet] object NDArrayMacro {
     val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
     // scalastyle:off println
     if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
-          && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
+      && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
       println("NDArray function definition:\n" + docStr)
     }
     // scalastyle:on println
-    (aliasName, new NDArrayFunction(handle))
+    val argList = argNames zip argTypes map { case (argName, argType) =>
+      val typeAndOption = argumentCleaner(argType)
+      new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
+    }
+    new NDArrayFunction(aliasName, argList.toList)
   }
 }

-- 
To stop receiving notification emails like this one, please contact
liuyizhi@apache.org.