You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) was not preserved in this copy.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/26 18:53:51 UTC
[incubator-mxnet] branch java-api updated: [MXNET-984] Java NDArray
Documentation Generation (#12835)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch java-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/java-api by this push:
new 5aaa729 [MXNET-984] Java NDArray Documentation Generation (#12835)
5aaa729 is described below
commit 5aaa72998e180e56d4b21c90d8791928661754c3
Author: Lanking <la...@live.com>
AuthorDate: Fri Oct 26 11:53:34 2018 -0700
[MXNET-984] Java NDArray Documentation Generation (#12835)
* cherry pick javaDoc changes
* update NDArray changes
* refactoring change and merge all docGen in a single place
* clean the scalastyle
* take on Piyush nit
* drop the comments
---
.../scala/org/apache/mxnet/javaapi/NDArray.scala | 2 +-
.../scala/org/apache/mxnet/APIDocGenerator.scala | 151 ++++++++++++++++-----
.../apache/mxnet/javaapi/JavaNDArrayMacro.scala | 6 +-
.../org/apache/mxnet/utils/CToScalaUtils.scala | 9 +-
4 files changed, 124 insertions(+), 44 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index c77b440..96119be 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -22,7 +22,7 @@ import org.apache.mxnet.javaapi.DType.DType
import collection.JavaConverters._
@AddJNDArrayAPIs(false)
-object NDArray {
+object NDArray extends NDArrayBase {
implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)
implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index b4efa65..44d47a2 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -42,6 +42,8 @@ private[mxnet] object APIDocGenerator{
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
+ // Generate Java API documentation
+ hashCollector += javaClassGen(FILE_PATH + "javaapi/")
val finalHash = hashCollector.mkString("\n")
}
@@ -52,8 +54,45 @@ private[mxnet] object APIDocGenerator{
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}
- def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
- // scalastyle:off
+ /**
+ * Write a generated abstract class to `filePath + packageName + ".scala"`.
+ * The file holds the Apache license header, a scalastyle:off marker, the
+ * given package declaration, an import of the Experimental annotation and
+ * an abstract class named `packageName` whose body is `absFuncs`.
+ * @param filePath prefix the file name is appended to directly, so it should
+ * end with a path separator
+ * @param packageName name of the generated abstract class (and of the file)
+ * @param packageDef package declaration line for the generated file
+ * @param absFuncs member definitions, with their ScalaDoc, for the class body
+ * @return the hash MD5Generator computes for the generated content
+ */
+ def fileGen(filePath : String, packageName : String, packageDef : String,
+ absFuncs : List[String]) : String = {
+ val apacheLicense =
+ """/*
+ |* Licensed to the Apache Software Foundation (ASF) under one or more
+ |* contributor license agreements. See the NOTICE file distributed with
+ |* this work for additional information regarding copyright ownership.
+ |* The ASF licenses this file to You under the Apache License, Version 2.0
+ |* (the "License"); you may not use this file except in compliance with
+ |* the License. You may obtain a copy of the License at
+ |*
+ |* http://www.apache.org/licenses/LICENSE-2.0
+ |*
+ |* Unless required by applicable law or agreed to in writing, software
+ |* distributed under the License is distributed on an "AS IS" BASIS,
+ |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ |* See the License for the specific language governing permissions and
+ |* limitations under the License.
+ |*/
+ |""".stripMargin
+ val scalaStyle = "// scalastyle:off"
+ val imports = "import org.apache.mxnet.annotation.Experimental"
+ val absClassDef = s"abstract class $packageName"
+
+ // Join with "\n" instead of an interpolated s"""...""".stripMargin: stripMargin
+ // runs after interpolation and would also strip any generated doc line that
+ // begins with (whitespace and) '|', e.g. table rows in operator docs.
+ val body = Seq(scalaStyle, packageDef, imports, s"$absClassDef {",
+ absFuncs.mkString("\n"), "}").mkString("\n")
+ val finalStr = s"$apacheLicense\n$body"
+ // Write explicitly as UTF-8 (backend docs are not guaranteed to be ASCII)
+ // and always release the file handle, even if the write fails.
+ val pw = new PrintWriter(new File(filePath + s"$packageName.scala"), "UTF-8")
+ try {
+ pw.write(finalStr)
+ } finally {
+ pw.close()
+ }
+ MD5Generator(finalStr)
+ }
+
+ def absClassGen(filePath : String, isSymbol : Boolean) : String = {
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// Defines Operators that should not generated
val notGenerated = Set("Custom")
@@ -66,19 +105,27 @@ private[mxnet] object APIDocGenerator{
s"$scalaDoc\n$defBody"
})
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
- val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICE [...]
- val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
- val imports = "import org.apache.mxnet.annotation.Experimental"
- val absClassDef = s"abstract class $packageName"
- val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
- val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
- pw.write(finalStr)
- pw.close()
- MD5Generator(finalStr)
+ fileGen(filePath, packageName, packageDef, absFuncs)
+ }
+
+ /**
+ * Generate the Java API base class `NDArrayBase` in package
+ * org.apache.mxnet.javaapi from the backend operator list.
+ * NDArray operators are fetched with javaapi return types; names starting
+ * with "_" and the "Custom" operator are skipped. Each remaining operator
+ * contributes its ScalaDoc followed by generateJavaAPISignature's output.
+ * @param filePath prefix the NDArrayBase.scala file name is appended to
+ * @return the hash returned by fileGen for the generated content
+ */
+ def javaClassGen(filePath : String) : String = {
+ val notGenerated = Set("Custom")
+ val absClassFunctions = getSymbolNDArrayMethods(false, true)
+ // TODO: Add Filter to the same location in case of refactor
+ val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
+ .filterNot(ele => notGenerated.contains(ele.name))
+ .map(absClassFunction => {
+ val scalaDoc = generateAPIDocFromBackend(absClassFunction)
+ val defBody = generateJavaAPISignature(absClassFunction)
+ s"$scalaDoc\n$defBody"
+ })
+ val packageName = "NDArrayBase"
+ val packageDef = "package org.apache.mxnet.javaapi"
+ fileGen(filePath, packageName, packageDef, absFuncs)
}
- def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
+ def nonTypeSafeClassGen(filePath : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.map(absClassFunction => {
@@ -93,17 +140,23 @@ private[mxnet] object APIDocGenerator{
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
- val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICE [...]
- val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
- val imports = "import org.apache.mxnet.annotation.Experimental"
- val absClassDef = s"abstract class $packageName"
- val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
- import java.io._
- val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
- pw.write(finalStr)
- pw.close()
- MD5Generator(finalStr)
+ fileGen(filePath, packageName, packageDef, absFuncs)
+ }
+
+ /**
+ * Map operator argument names that are reserved words in Scala to legal
+ * identifiers: "var" becomes "vari" and "type" becomes "typeOf".
+ * Any other name is returned unchanged.
+ * @param in the original argument name
+ * @return a name usable as a Scala parameter name
+ */
+ def safetyNameCheck(in : String) : String = {
+ in match {
+ case "var" => "vari"
+ case "type" => "typeOf"
+ case _ => in
+ }
}
// Generate ScalaDoc type
@@ -115,11 +168,7 @@ private[mxnet] object APIDocGenerator{
})
desc += " * </pre>"
val params = func.listOfArgs.map({ absClassArg =>
- val currArgName = absClassArg.argName match {
- case "var" => "vari"
- case "type" => "typeOf"
- case _ => absClassArg.argName
- }
+ val currArgName = safetyNameCheck(absClassArg.argName)
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
@@ -133,11 +182,7 @@ private[mxnet] object APIDocGenerator{
def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
- val currArgName = absClassArg.argName match {
- case "var" => "vari"
- case "type" => "typeOf"
- case _ => absClassArg.argName
- }
+ val currArgName = safetyNameCheck(absClassArg.argName)
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
@@ -157,23 +202,57 @@ private[mxnet] object APIDocGenerator{
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
}
+ /**
+ * Generate the Java-facing signature for one operator: a factory method that
+ * takes the required arguments and returns `<op>BuilderBase`, followed by that
+ * abstract builder class with one setter per optional argument plus
+ * `setout` and `invoke`.
+ * @param func the operator description
+ * @return Scala source text for the factory method and its builder base class
+ */
+ def generateJavaAPISignature(func : absClassFunction) : String = {
+ val argDef = ListBuffer[String]()
+ val classDef = ListBuffer[String]()
+ func.listOfArgs.foreach(absClassArg => {
+ val currArgName = safetyNameCheck(absClassArg.argName)
+ // scalastyle:off
+ if (absClassArg.isOptional) {
+ // The setter keeps the raw name (the macro-generated builder presumably
+ // defines set<argName> — TODO confirm), but the parameter must use the
+ // escaped name: a raw `type`/`var` parameter would not compile.
+ classDef += s"def set${absClassArg.argName}($currArgName : ${absClassArg.argType}) : ${func.name}BuilderBase"
+ }
+ else {
+ argDef += s"$currArgName : ${absClassArg.argType}"
+ }
+ // scalastyle:on
+ })
+ classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase"
+ classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn"
+ val experimentalTag = "@Experimental"
+ // scalastyle:off
+ val signature = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : ${func.name}BuilderBase\n"
+ // scalastyle:on
+ val builderBase = s"abstract class ${func.name}BuilderBase {\n ${classDef.mkString("\n ")}\n}"
+ signature + builderBase
+ }
+
// List and add all the atomic symbol functions to current module.
- private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
+ /**
+ * Describe every public backend operator as an absClassFunction.
+ * @param isSymbol true for Symbol return types, false for NDArray
+ * @param isJava true to qualify return types with org.apache.mxnet.javaapi
+ * @return one entry per operator name, case-insensitively de-duplicated
+ */
+ private def getSymbolNDArrayMethods(isSymbol : Boolean,
+ isJava : Boolean = false): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
+ val returnHeader = if (isJava) "org.apache.mxnet.javaapi." else "org.apache.mxnet."
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
- makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
- }).toList.filterNot(_.name.startsWith("_"))
+ makeAtomicSymbolFunction(opHandle.value, opName, returnHeader + returnType)
+ }).filterNot(_.name.startsWith("_")).groupBy(_.name.toLowerCase).values.map(aliases => {
+ // Names that differ only in case are aliases; skip the (presumably deprecated)
+ // uppercase-initial spelling by keeping the first lowercase-initial one,
+ // falling back to the last alias when none starts lowercase.
+ aliases.find(_.name.head.isLower).getOrElse(aliases.last)
+ }).toList
}
// Create an atomic symbol function by handle and function name.
- private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
+ private def makeAtomicSymbolFunction(handle: SymbolHandle,
+ aliasName: String, returnType : String)
: absClassFunction = {
val name = new RefString
val desc = new RefString
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
index c530c73..d5be97b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -120,12 +120,12 @@ private[mxnet] object JavaNDArrayMacro {
// scalastyle:off
// Combine and build the function string
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
- val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})"
+ val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")}) extends ${ndarrayfunction.name}BuilderBase"
val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}"
val classFinal = s"$classDef {$classBody}"
val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})"
val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
- val functionFinal = s"$functionDef = $functionBody"
+ val functionFinal = s"$functionDef : ${ndarrayfunction.name}BuilderBase = $functionBody"
// scalastyle:on
functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
@@ -195,7 +195,7 @@ private[mxnet] object JavaNDArrayMacro {
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType,
- "org.apache.mxnet.javaapi.NDArray", "javaapi.Shape")
+ "org.apache.mxnet.javaapi.NDArray")
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index 48d8fdf..2fd8b2e 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -22,9 +22,10 @@ private[mxnet] object CToScalaUtils {
// Convert C++ Types to Scala Types
def typeConversion(in : String, argType : String = "", argName : String,
- returnType : String, shapeType : String = "Shape") : String = {
+ returnType : String) : String = {
+ val header = returnType.split("\\.").dropRight(1)
in match {
- case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType"
+ case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> s"Array[$returnType]"
@@ -53,7 +54,7 @@ private[mxnet] object CToScalaUtils {
* @return (Scala_Type, isOptional)
*/
def argumentCleaner(argName: String, argType : String,
- returnType : String, shapeType : String = "Shape") : (String, Boolean) = {
+ returnType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -73,7 +74,7 @@ private[mxnet] object CToScalaUtils {
s"""expected "default=..." got ${commaRemoved(2)}""")
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
- val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, shapeType)
+ val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {