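You are viewing a plain-text version of this content. The canonical link to it was the hyperlinked word "here" on the original archive page; that hyperlink is not preserved in this plain-text rendering.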
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/26 18:53:51 UTC

[incubator-mxnet] branch java-api updated: [MXNET-984] Java NDArray Documentation Generation (#12835)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch java-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/java-api by this push:
     new 5aaa729  [MXNET-984] Java NDArray Documentation Generation (#12835)
5aaa729 is described below

commit 5aaa72998e180e56d4b21c90d8791928661754c3
Author: Lanking <la...@live.com>
AuthorDate: Fri Oct 26 11:53:34 2018 -0700

    [MXNET-984] Java NDArray Documentation Generation (#12835)
    
    * Cherry-pick the JavaDoc changes
    
    * Update the NDArray changes
    
    * Refactor: merge all doc generation (docGen) into a single place
    
    * Clean up scalastyle violations
    
    * Address Piyush's review nit
    
    * Drop the comments
---
 .../scala/org/apache/mxnet/javaapi/NDArray.scala   |   2 +-
 .../scala/org/apache/mxnet/APIDocGenerator.scala   | 151 ++++++++++++++++-----
 .../apache/mxnet/javaapi/JavaNDArrayMacro.scala    |   6 +-
 .../org/apache/mxnet/utils/CToScalaUtils.scala     |   9 +-
 4 files changed, 124 insertions(+), 44 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index c77b440..96119be 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -22,7 +22,7 @@ import org.apache.mxnet.javaapi.DType.DType
 import collection.JavaConverters._
 
 @AddJNDArrayAPIs(false)
-object NDArray {
+object NDArray extends NDArrayBase {
   implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)
 
   implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index b4efa65..44d47a2 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -42,6 +42,8 @@ private[mxnet] object APIDocGenerator{
     hashCollector += absClassGen(FILE_PATH, false)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
+    // Generate Java API documentation
+    hashCollector += javaClassGen(FILE_PATH + "javaapi/")
     val finalHash = hashCollector.mkString("\n")
   }
 
@@ -52,8 +54,45 @@ private[mxnet] object APIDocGenerator{
     org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
   }
 
-  def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
-    // scalastyle:off
+  def fileGen(filePath : String, packageName : String, packageDef : String,
+              absFuncs : List[String]) : String = {
+    val apacheLicense =
+      """/*
+        |* Licensed to the Apache Software Foundation (ASF) under one or more
+        |* contributor license agreements.  See the NOTICE file distributed with
+        |* this work for additional information regarding copyright ownership.
+        |* The ASF licenses this file to You under the Apache License, Version 2.0
+        |* (the "License"); you may not use this file except in compliance with
+        |* the License.  You may obtain a copy of the License at
+        |*
+        |*    http://www.apache.org/licenses/LICENSE-2.0
+        |*
+        |* Unless required by applicable law or agreed to in writing, software
+        |* distributed under the License is distributed on an "AS IS" BASIS,
+        |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+        |* See the License for the specific language governing permissions and
+        |* limitations under the License.
+        |*/
+        |""".stripMargin
+    val scalaStyle = "// scalastyle:off"
+    val imports = "import org.apache.mxnet.annotation.Experimental"
+    val absClassDef = s"abstract class $packageName"
+
+    val finalStr =
+      s"""$apacheLicense
+         |$scalaStyle
+         |$packageDef
+         |$imports
+         |$absClassDef {
+         |${absFuncs.mkString("\n")}
+         |}""".stripMargin
+    val pw = new PrintWriter(new File(filePath + s"$packageName.scala"))
+    pw.write(finalStr)
+    pw.close()
+    MD5Generator(finalStr)
+  }
+
+  def absClassGen(filePath : String, isSymbol : Boolean) : String = {
     val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
     // Defines Operators that should not generated
     val notGenerated = Set("Custom")
@@ -66,19 +105,27 @@ private[mxnet] object APIDocGenerator{
       s"$scalaDoc\n$defBody"
     })
     val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
-    val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements.  See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License.  You may obtain a copy of the License at\n*\n*    http://www.apache.org/licenses/LICE [...]
-    val scalaStyle = "// scalastyle:off"
     val packageDef = "package org.apache.mxnet"
-    val imports = "import org.apache.mxnet.annotation.Experimental"
-    val absClassDef = s"abstract class $packageName"
-    val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
-    val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
-    pw.write(finalStr)
-    pw.close()
-    MD5Generator(finalStr)
+    fileGen(filePath, packageName, packageDef, absFuncs)
+  }
+
+  def javaClassGen(filePath : String) : String = {
+    val notGenerated = Set("Custom")
+    val absClassFunctions = getSymbolNDArrayMethods(false, true)
+    // TODO: Add Filter to the same location in case of refactor
+    val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
+      .filterNot(ele => notGenerated.contains(ele.name))
+      .map(absClassFunction => {
+        val scalaDoc = generateAPIDocFromBackend(absClassFunction)
+        val defBody = generateJavaAPISignature(absClassFunction)
+        s"$scalaDoc\n$defBody"
+      })
+    val packageName = "NDArrayBase"
+    val packageDef = "package org.apache.mxnet.javaapi"
+    fileGen(filePath, packageName, packageDef, absFuncs)
   }
 
-  def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
+  def nonTypeSafeClassGen(filePath : String, isSymbol : Boolean) : String = {
     // scalastyle:off
     val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
     val absFuncs = absClassFunctions.map(absClassFunction => {
@@ -93,17 +140,23 @@ private[mxnet] object APIDocGenerator{
       }
     })
     val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
-    val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements.  See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License.  You may obtain a copy of the License at\n*\n*    http://www.apache.org/licenses/LICE [...]
-    val scalaStyle = "// scalastyle:off"
     val packageDef = "package org.apache.mxnet"
-    val imports = "import org.apache.mxnet.annotation.Experimental"
-    val absClassDef = s"abstract class $packageName"
-    val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
-    import java.io._
-    val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
-    pw.write(finalStr)
-    pw.close()
-    MD5Generator(finalStr)
+    fileGen(filePath, packageName, packageDef, absFuncs)
+  }
+
+  /**
+    * Some of the C++ type name is not valid in Scala
+    * such as var and type. This method is to convert
+    * them into other names to get it passed
+    * @param in the input String
+    * @return converted name string
+    */
+  def safetyNameCheck(in : String) : String = {
+    in match {
+      case "var" => "vari"
+      case "type" => "typeOf"
+      case _ => in
+    }
   }
 
   // Generate ScalaDoc type
@@ -115,11 +168,7 @@ private[mxnet] object APIDocGenerator{
     })
     desc += "  * </pre>"
     val params = func.listOfArgs.map({ absClassArg =>
-      val currArgName = absClassArg.argName match {
-                case "var" => "vari"
-                case "type" => "typeOf"
-                case _ => absClassArg.argName
-      }
+      val currArgName = safetyNameCheck(absClassArg.argName)
       s"  * @param $currArgName\t\t${absClassArg.argDesc}"
     })
     val returnType = s"  * @return ${func.returnType}"
@@ -133,11 +182,7 @@ private[mxnet] object APIDocGenerator{
   def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
     var argDef = ListBuffer[String]()
     func.listOfArgs.foreach(absClassArg => {
-      val currArgName = absClassArg.argName match {
-        case "var" => "vari"
-        case "type" => "typeOf"
-        case _ => absClassArg.argName
-      }
+      val currArgName = safetyNameCheck(absClassArg.argName)
       if (absClassArg.isOptional) {
         argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
       }
@@ -157,23 +202,57 @@ private[mxnet] object APIDocGenerator{
     s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
   }
 
+  def generateJavaAPISignature(func : absClassFunction) : String = {
+    var argDef = ListBuffer[String]()
+    var classDef = ListBuffer[String]()
+    func.listOfArgs.foreach(absClassArg => {
+      val currArgName = safetyNameCheck(absClassArg.argName)
+      // scalastyle:off
+      if (absClassArg.isOptional) {
+        classDef += s"def set${absClassArg.argName}(${absClassArg.argName} : ${absClassArg.argType}) : ${func.name}BuilderBase"
+      }
+      else {
+        argDef += s"$currArgName : ${absClassArg.argType}"
+      }
+      // scalastyle:on
+    })
+    classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase"
+    classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn"
+    val experimentalTag = "@Experimental"
+    // scalastyle:off
+    var finalStr = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : ${func.name}BuilderBase\n"
+    // scalastyle:on
+    finalStr += s"abstract class ${func.name}BuilderBase {\n  ${classDef.mkString("\n  ")}\n}"
+    finalStr
+  }
+
 
   // List and add all the atomic symbol functions to current module.
-  private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
+  private def getSymbolNDArrayMethods(isSymbol : Boolean,
+                                      isJava : Boolean = false): List[absClassFunction] = {
     val opNames = ListBuffer.empty[String]
     val returnType = if (isSymbol) "Symbol" else "NDArray"
+    val returnHeader = if (isJava) "org.apache.mxnet.javaapi." else "org.apache.mxnet."
     _LIB.mxListAllOpNames(opNames)
     // TODO: Add '_linalg_', '_sparse_', '_image_' support
     // TODO: Add Filter to the same location in case of refactor
     opNames.map(opName => {
       val opHandle = new RefLong
       _LIB.nnGetOpHandle(opName, opHandle)
-      makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
-    }).toList.filterNot(_.name.startsWith("_"))
+      makeAtomicSymbolFunction(opHandle.value, opName, returnHeader + returnType)
+    }).filterNot(_.name.startsWith("_")).groupBy(_.name.toLowerCase).map(ele => {
+      // Pattern matching for not generating depreciated method
+      if (ele._2.length == 1) ele._2.head
+      else {
+        if (ele._2.head.name.head.isLower) ele._2.head
+        else ele._2.last
+      }
+    }).toList
   }
 
   // Create an atomic symbol function by handle and function name.
-  private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
+  private def makeAtomicSymbolFunction(handle: SymbolHandle,
+                                       aliasName: String, returnType : String)
   : absClassFunction = {
     val name = new RefString
     val desc = new RefString
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
index c530c73..d5be97b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -120,12 +120,12 @@ private[mxnet] object JavaNDArrayMacro {
       // scalastyle:off
       // Combine and build the function string
       impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
-      val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})"
+      val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")}) extends ${ndarrayfunction.name}BuilderBase"
       val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}"
       val classFinal = s"$classDef {$classBody}"
       val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})"
       val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
-      val functionFinal = s"$functionDef = $functionBody"
+      val functionFinal = s"$functionDef : ${ndarrayfunction.name}BuilderBase = $functionBody"
       // scalastyle:on
       functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
       classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
@@ -195,7 +195,7 @@ private[mxnet] object JavaNDArrayMacro {
     val argList = argNames zip argTypes map { case (argName, argType) =>
       val typeAndOption =
         CToScalaUtils.argumentCleaner(argName, argType,
-          "org.apache.mxnet.javaapi.NDArray", "javaapi.Shape")
+          "org.apache.mxnet.javaapi.NDArray")
       new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
     }
     new NDArrayFunction(aliasName, argList.toList)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index 48d8fdf..2fd8b2e 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -22,9 +22,10 @@ private[mxnet] object CToScalaUtils {
 
   // Convert C++ Types to Scala Types
   def typeConversion(in : String, argType : String = "", argName : String,
-                     returnType : String, shapeType : String = "Shape") : String = {
+                     returnType : String) : String = {
+    val header = returnType.split("\\.").dropRight(1)
     in match {
-      case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType"
+      case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
       case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
       case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
       => s"Array[$returnType]"
@@ -53,7 +54,7 @@ private[mxnet] object CToScalaUtils {
     * @return (Scala_Type, isOptional)
     */
   def argumentCleaner(argName: String, argType : String,
-                      returnType : String, shapeType : String = "Shape") : (String, Boolean) = {
+                      returnType : String) : (String, Boolean) = {
     val spaceRemoved = argType.replaceAll("\\s+", "")
     var commaRemoved : Array[String] = new Array[String](0)
     // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -73,7 +74,7 @@ private[mxnet] object CToScalaUtils {
         s"""expected "default=..." got ${commaRemoved(2)}""")
       (typeConversion(commaRemoved(0), argType, argName, returnType), true)
     } else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
-      val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, shapeType)
+      val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
       val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
       (tempType, tempOptional)
     } else {