You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/30 14:41:37 UTC

[GitHub] mdespriee closed pull request #12489: [WIP][MXNET-918] Random api

mdespriee closed pull request #12489: [WIP][MXNET-918] Random api
URL: https://github.com/apache/incubator-mxnet/pull/12489
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index f9f2dbe42a9..83712590867 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -40,6 +40,8 @@ object NDArray extends NDArrayBase {
 
   val api = NDArrayAPI
 
+  val random = NDArrayRandomAPI
+
   private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = {
     froms.foreach { from =>
       val weakRef = new WeakReference(from)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
index 1d8551c1b1e..cb507877520 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
@@ -23,3 +23,13 @@ package org.apache.mxnet
 object NDArrayAPI extends NDArrayAPIBase {
   // TODO: Implement CustomOp for NDArray
 }
+
+@AddNDArrayRandomAPIs(false)
+/**
+  * typesafe NDArray random module: NDArray.random._
+  * Main code will be generated during compile time through Macros
+  */
+object NDArrayRandomAPI extends NDArrayRandomAPIBase {
+
+}
+
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 4472a8426f9..17f3636c2e2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -840,8 +840,12 @@ object Symbol extends SymbolBase {
   private val functions: Map[String, SymbolFunction] = initSymbolModule()
   private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
 
+  type SymbolOrFloat = Any
+
   val api = SymbolAPI
 
+  val random = SymbolRandomAPI
+
   def pow(sym1: Symbol, sym2: Symbol): Symbol = {
     Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
   }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
index 1bfb0559cf9..e5b55f6e1ee 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
@@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase {
     Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
   }
 }
+
+@AddSymbolRandomAPIs(false)
+/**
+  * typesafe Symbol random module: Symbol.random._
+  * Main code will be generated during compile time through Macros
+  */
+object SymbolRandomAPI extends SymbolRandomAPIBase {
+
+}
+
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 5d88bb39e50..f90487739f3 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -576,4 +576,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     assert(arr.internal.toDoubleArray === Array(2d, 2d))
     assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte))
   }
+
+  test("random module is generated properly") {
+    val lam = NDArray.ones(1, 2)
+    val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4)))
+    val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)))
+    assert(rnd.shape === Shape(1, 2, 3, 4))
+    assert(rnd2.shape === Shape(3, 4))
+  }
+
+  test("random module is generated properly - special case of 'normal'") {
+    val mu = NDArray.ones(1, 2)
+    val sigma = NDArray.ones(1, 2) * 2
+    val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4)))
+    val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f),
+      shape = Some(Shape(3, 4)))
+    assert(rnd.shape === Shape(1, 2, 3, 4))
+    assert(rnd2.shape === Shape(3, 4))
+  }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
index ebb61d7d4bf..4ca8a456445 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
@@ -19,7 +19,9 @@ package org.apache.mxnet
 
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
+
 class SymbolSuite extends FunSuite with BeforeAndAfterAll {
+
   test("symbol compose") {
     val data = Symbol.Variable("data")
 
@@ -71,4 +73,27 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
     val data2 = data.clone()
     assert(data.toJson === data2.toJson)
   }
+
+  test("random module is generated properly") {
+    val lam = Symbol.Variable("lam")
+    val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2)))
+    val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2)))
+    // scalastyle:off println
+    println(s"Symbol.random.poisson debug info: ${rnd.debugStr}")
+    println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}")
+    // scalastyle:on println
+  }
+
+  test("random module is generated properly - special case of 'normal'") {
+    val loc = Symbol.Variable("loc")
+    val scale = Symbol.Variable("scale")
+    val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale),
+      shape = Some(Shape(2, 2)))
+    val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f),
+      shape = Some(Shape(2, 2)))
+    // scalastyle:off println
+    println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
+    println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}")
+    // scalastyle:on println
+  }
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index b4efa659443..8dbab4a984d 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -17,8 +17,6 @@
 
 package org.apache.mxnet
 
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.CToScalaUtils
 import java.io._
 import java.security.MessageDigest
 
@@ -29,77 +27,115 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
   * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
   * The code will be executed during Macros stage and file live in Core stage
   */
-private[mxnet] object APIDocGenerator{
-  case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
-  case class absClassFunction(name : String, desc : String,
-                           listOfArgs: List[absClassArg], returnType : String)
+private[mxnet] object APIDocGenerator extends GeneratorBase {
+  type absClassArg = Arg
+  type absClassFunction = Func
 
-
-  def main(args: Array[String]) : Unit = {
+  def main(args: Array[String]): Unit = {
     val FILE_PATH = args(0)
     val hashCollector = ListBuffer[String]()
     hashCollector += absClassGen(FILE_PATH, true)
     hashCollector += absClassGen(FILE_PATH, false)
+    hashCollector += absRndClassGen(FILE_PATH, true)
+    hashCollector += absRndClassGen(FILE_PATH, false)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
     val finalHash = hashCollector.mkString("\n")
   }
 
-  def MD5Generator(input : String) : String = {
+  def MD5Generator(input: String): String = {
     val md = MessageDigest.getInstance("MD5")
     md.update(input.getBytes("UTF-8"))
     val digest = md.digest()
     org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
   }
 
-  def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
-    // scalastyle:off
-    val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
-    // Defines Operators that should not generated
+  def absRndClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+    val funcs = buildRandomFunctionList(isSymbol)
+
+    val body = funcs.map(func => {
+      val scalaDoc = generateAPIDocFromBackend(func)
+      val decl = generateAPISignature(func, isSymbol)
+      s"$scalaDoc\n$decl"
+    })
+    writeFile(
+      FILE_PATH,
+      if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
+      body)
+  }
+
+  def absClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
     val notGenerated = Set("Custom")
-    // TODO: Add Filter to the same location in case of refactor
-    val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
+    val funcs = buildFunctionList(isSymbol)
+      .filterNot(_.name.startsWith("_"))
       .filterNot(ele => notGenerated.contains(ele.name))
-      .map(absClassFunction => {
-      val scalaDoc = generateAPIDocFromBackend(absClassFunction)
-      val defBody = generateAPISignature(absClassFunction, isSymbol)
-      s"$scalaDoc\n$defBody"
+    val body = funcs.map(func => {
+      val scalaDoc = generateAPIDocFromBackend(func)
+      val decl = generateAPISignature(func, isSymbol)
+      s"$scalaDoc\n$decl"
     })
-    val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
-    val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements.  See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License.  You may obtain a copy of the License at\n*\n*    http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
-    val scalaStyle = "// scalastyle:off"
-    val packageDef = "package org.apache.mxnet"
-    val imports = "import org.apache.mxnet.annotation.Experimental"
-    val absClassDef = s"abstract class $packageName"
-    val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
-    val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
-    pw.write(finalStr)
-    pw.close()
-    MD5Generator(finalStr)
+    writeFile(
+      FILE_PATH,
+      if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+      body)
   }
 
-  def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
-    // scalastyle:off
-    val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
-    val absFuncs = absClassFunctions.map(absClassFunction => {
-      val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
-      if (isSymbol) {
-        val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
-        s"$scalaDoc\n$defBody"
-      } else {
-        val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
-        val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
-        s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
-      }
-    })
+  def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+    val absClassFunctions = buildFunctionList(isSymbol)
+    val absFuncs = absClassFunctions
+      .filterNot(_.name.startsWith("_"))
+      .map(absClassFunction => {
+        val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
+        if (isSymbol) {
+          val defBody =
+            s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)" +
+              s"(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): " +
+              s"org.apache.mxnet.Symbol"
+          s"$scalaDoc\n$defBody"
+        } else {
+          val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)" +
+            s"(args: Any*): " +
+            s"org.apache.mxnet.NDArrayFuncReturn"
+          val defBody = s"def ${absClassFunction.name}(args: Any*): " +
+            s"org.apache.mxnet.NDArrayFuncReturn"
+          s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
+        }
+      })
     val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
-    val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements.  See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License.  You may obtain a copy of the License at\n*\n*    http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
+    writeFile(FILE_PATH, packageName, absFuncs)
+  }
+
+  def writeFile(FILE_PATH: String, packageName: String, body: Seq[String]): String = {
+    val apacheLicence =
+      """/*
+        |* Licensed to the Apache Software Foundation (ASF) under one or more
+        |* contributor license agreements.  See the NOTICE file distributed with
+        |* this work for additional information regarding copyright ownership.
+        |* The ASF licenses this file to You under the Apache License, Version 2.0
+        |* (the "License"); you may not use this file except in compliance with
+        |* the License.  You may obtain a copy of the License at
+        |*
+        |*    http://www.apache.org/licenses/LICENSE-2.0
+        |*
+        |* Unless required by applicable law or agreed to in writing, software
+        |* distributed under the License is distributed on an "AS IS" BASIS,
+        |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+        |* See the License for the specific language governing permissions and
+        |* limitations under the License.
+        |*/
+        |""".stripMargin
     val scalaStyle = "// scalastyle:off"
     val packageDef = "package org.apache.mxnet"
     val imports = "import org.apache.mxnet.annotation.Experimental"
     val absClassDef = s"abstract class $packageName"
-    val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
-    import java.io._
+    val finalStr =
+      s"""$apacheLicence
+         |$scalaStyle
+         |$packageDef
+         |$imports
+         |$absClassDef {
+         |${body.mkString("\n")}
+         |}""".stripMargin
     val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
     pw.write(finalStr)
     pw.close()
@@ -107,20 +143,15 @@ private[mxnet] object APIDocGenerator{
   }
 
   // Generate ScalaDoc type
-  def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
+  def generateAPIDocFromBackend(func: absClassFunction, withParam: Boolean = true): String = {
     val desc = ArrayBuffer[String]()
     desc += "  * <pre>"
-      func.desc.split("\n").foreach({ currStr =>
+    func.desc.split("\n").foreach({ currStr =>
       desc += s"  * $currStr"
     })
     desc += "  * </pre>"
     val params = func.listOfArgs.map({ absClassArg =>
-      val currArgName = absClassArg.argName match {
-                case "var" => "vari"
-                case "type" => "typeOf"
-                case _ => absClassArg.argName
-      }
-      s"  * @param $currArgName\t\t${absClassArg.argDesc}"
+      s"  * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
     })
     val returnType = s"  * @return ${func.returnType}"
     if (withParam) {
@@ -130,65 +161,23 @@ private[mxnet] object APIDocGenerator{
     }
   }
 
-  def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
-    var argDef = ListBuffer[String]()
-    func.listOfArgs.foreach(absClassArg => {
-      val currArgName = absClassArg.argName match {
-        case "var" => "vari"
-        case "type" => "typeOf"
-        case _ => absClassArg.argName
-      }
-      if (absClassArg.isOptional) {
-        argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
-      }
-      else {
-        argDef += s"$currArgName : ${absClassArg.argType}"
-      }
-    })
-    var returnType = func.returnType
+  def generateAPISignature(func: absClassFunction, isSymbol: Boolean): String = {
+    val argDef = ListBuffer[String]()
+
+    argDef ++= buildArgDefs(func)
+
     if (isSymbol) {
       argDef += "name : String = null"
       argDef += "attr : Map[String, String] = null"
     } else {
       argDef += "out : Option[NDArray] = None"
-      returnType = "org.apache.mxnet.NDArrayFuncReturn"
     }
+
+    val returnType = func.returnType
+
     val experimentalTag = "@Experimental"
     s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
   }
 
 
-  // List and add all the atomic symbol functions to current module.
-  private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
-    val opNames = ListBuffer.empty[String]
-    val returnType = if (isSymbol) "Symbol" else "NDArray"
-    _LIB.mxListAllOpNames(opNames)
-    // TODO: Add '_linalg_', '_sparse_', '_image_' support
-    // TODO: Add Filter to the same location in case of refactor
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
-    }).toList.filterNot(_.name.startsWith("_"))
-  }
-
-  // Create an atomic symbol function by handle and function name.
-  private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
-  : absClassFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
-      val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
-      new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
-    }
-    new absClassFunction(aliasName, desc.value, argList.toList, returnType)
-  }
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
new file mode 100644
index 00000000000..84b27ea79fb
--- /dev/null
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB}
+import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
+
+import scala.collection.mutable.ListBuffer
+import scala.reflect.macros.blackbox
+
+abstract class GeneratorBase {
+  type Handle = Long
+
+  case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) {
+    def safeArgName: String = argName match {
+      case "var" => "vari"
+      case "type" => "typeOf"
+      case _ => argName
+    }
+  }
+
+  case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String)
+
+  protected def buildFunctionList(isSymbol: Boolean): List[Func] = {
+    val opNames = ListBuffer.empty[String]
+    _LIB.mxListAllOpNames(opNames)
+    opNames.map(opName => {
+      val opHandle = new RefLong
+      _LIB.nnGetOpHandle(opName, opHandle)
+      makeAtomicFunction(opHandle.value, opName, isSymbol)
+    }).toList
+  }
+
+  protected def buildRandomFunctionList(isSymbol: Boolean): List[Func] = {
+    buildFunctionList(isSymbol)
+      .filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
+      .map(f => f.copy(name = f.name.stripPrefix("_")))
+      // unify _random and _sample
+      .map(f => unifyRandom(f, isSymbol))
+      // deduplicate
+      .groupBy(_.name)
+      .mapValues(_.head)
+      .values
+      .toList
+  }
+
+  protected def unifyRandom(func: Func, isSymbol: Boolean): Func = {
+    var typeConv = if (isSymbol)
+      Map(
+        "org.apache.mxnet.Symbol" -> "Any",
+        "org.apache.mxnet.Base.MXFloat" -> "Any",
+        "Int" -> "Any"
+      )
+    else
+      Map(
+        "org.apache.mxnet.NDArray" -> "Any",
+        "org.apache.mxnet.Base.MXFloat" -> "Any",
+        "Int" -> "Any"
+      )
+
+    func.copy(
+      name = func.name.replaceAll("(random|sample)_", ""),
+      listOfArgs = func.listOfArgs
+        .map { arg =>
+          // This is a hack to handle the fact that random_normal and sample_normal have
+          // inconsistent parameter naming in the back-end
+          if (arg.argName == "loc") arg.copy(argName = "mu")
+          else if (arg.argName == "scale") arg.copy(argName = "sigma")
+          else arg
+        }
+        .map { arg =>
+          arg.copy(argType = typeConv.getOrElse(arg.argType, arg.argType))
+        }
+    )
+  }
+
+  protected def makeAtomicFunction(handle: Handle, aliasName: String, isSymbol: Boolean): Func = {
+    val name = new RefString
+    val desc = new RefString
+    val keyVarNumArgs = new RefString
+    val numArgs = new RefInt
+    val argNames = ListBuffer.empty[String]
+    val argTypes = ListBuffer.empty[String]
+    val argDescs = ListBuffer.empty[String]
+
+    _LIB.mxSymbolGetAtomicSymbolInfo(
+      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
+    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs)
+    val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) {
+      s"This function support variable length of positional input (${keyVarNumArgs.value})."
+    } else {
+      ""
+    }
+    val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"
+    val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
+    // scalastyle:off println
+    if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
+      && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
+      println("Function definition:\n" + docStr)
+    }
+    // scalastyle:on println
+    val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
+      val family = if (isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArray"
+      val typeAndOption =
+        CToScalaUtils.argumentCleaner(argName, argType, family)
+      Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
+    }
+    val returnType = if (isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArrayFuncReturn"
+    Func(aliasName, desc.value, argList.toList, returnType)
+  }
+
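+  /**
+    * Generate class structure for all function APIs
+    *
+    * @param c the macro context
+    * @param funcDef generated function definitions (DefDef) to append to the annotated body
+    * @param annottees the annotated class or object definitions
+    * @return the annottees with funcDef appended to their existing definitions
+    */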
+  protected def structGeneration(c: blackbox.Context)
+                                (funcDef: List[c.universe.DefDef], annottees: c.Expr[Any]*)
+  : c.Expr[Any] = {
+    import c.universe._
+    val inputs = annottees.map(_.tree).toList
+    // pattern match on the inputs
+    val modDefs = inputs map {
+      case ClassDef(mods, name, something, template) =>
+        val q = template match {
+          case Template(superMaybe, emptyValDef, defs) =>
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
+          case ex =>
+            throw new IllegalArgumentException(s"Invalid template: $ex")
+        }
+        ClassDef(mods, name, something, q)
+      case ModuleDef(mods, name, template) =>
+        val q = template match {
+          case Template(superMaybe, emptyValDef, defs) =>
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
+          case ex =>
+            throw new IllegalArgumentException(s"Invalid template: $ex")
+        }
+        ModuleDef(mods, name, q)
+      case ex =>
+        throw new IllegalArgumentException(s"Invalid macro input: $ex")
+    }
+    // wrap the result up in an Expr, and return it
+    val result = c.Expr(Block(modDefs, Literal(Constant())))
+    result
+  }
+
+  protected def buildArgDefs(func: Func): List[String] = {
+    func.listOfArgs.map(arg =>
+      if (arg.isOptional)
+        s"${arg.safeArgName} : Option[${arg.argType}] = None"
+      else
+        s"${arg.safeArgName} : ${arg.argType}"
+    )
+  }
+
+
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index 2d3a1c7ec5a..8b7ad67ba8b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -17,11 +17,8 @@
 
 package org.apache.mxnet
 
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
-
 import scala.annotation.StaticAnnotation
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
 
@@ -33,21 +30,37 @@ private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation
   private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs
 }
 
-private[mxnet] object NDArrayMacro {
-  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
-  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
+private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeRandomAPIDefs
+}
+
+
+private[mxnet] object NDArrayMacro extends GeneratorBase {
+  type NDArrayArg = Arg
+  type NDArrayFunction = Func
 
   // scalastyle:off havetype
   def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
     impl(c)(annottees: _*)
   }
+
   def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    typeSafeAPIImpl(c)(annottees: _*)
+    typedAPIImpl(c)(annottees: _*)
   }
+
+  def typeSafeRandomAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+    typedRandomAPIImpl(c)(annottees: _*)
+  }
+
   // scalastyle:off havetype
 
-  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
+  private val ndarrayFunctions = buildFunctionList(false)
 
+  private val rndFunctions = buildRandomFunctionList(false)
+
+  /**
+    * Implementation for fixed input API structure
+    */
   private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
 
@@ -60,26 +73,29 @@ private[mxnet] object NDArrayMacro {
       else ndarrayFunctions.filterNot(_.name.startsWith("_"))
     }
 
-     val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
-        val funcName = NDArrayfunction.name
-        val termName = TermName(funcName)
-       Seq(
-            // scalastyle:off
-            // (yizhi) We are investigating a way to make these functions type-safe
-            // and waiting to see the new approach is stable enough.
-            // Thus these functions may be deprecated in the future.
-            // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
-            q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
-            // e.g def transpose(args: Any*)
-            q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
-            // scalastyle:on
-          )
-        }
-
-    structGeneration(c)(functionDefs, annottees : _*)
+    val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
+      val funcName = NDArrayfunction.name
+      val termName = TermName(funcName)
+      Seq(
+        // scalastyle:off
+        // (yizhi) We are investigating a way to make these functions type-safe
+        // and waiting to see the new approach is stable enough.
+        // Thus these functions may be deprecated in the future.
+        // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
+        q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
+        // e.g def transpose(args: Any*)
+        q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
+        // scalastyle:on
+      )
+    }
+
+    structGeneration(c)(functionDefs, annottees: _*)
   }
 
-  private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+  /**
+    * Implementation for the typed API NDArray.api.<functionName>
+    */
+  private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
 
     val isContrib: Boolean = c.prefix.tree match {
@@ -91,146 +107,131 @@ private[mxnet] object NDArrayMacro {
     val newNDArrayFunctions = {
       if (isContrib) ndarrayFunctions.filter(
         func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
-      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
+      else ndarrayFunctions.filterNot(f => f.name.startsWith("_"))
     }.filterNot(ele => notGenerated.contains(ele.name))
 
-    val functionDefs = newNDArrayFunctions.map { ndarrayfunction =>
-
-      // Construct argument field
-      var argDef = ListBuffer[String]()
-      // Construct Implementation field
-      var impl = ListBuffer[String]()
-      impl += "val map = scala.collection.mutable.Map[String, Any]()"
-      impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
-      ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
-        // var is a special word used to define variable in Scala,
-        // need to changed to something else in order to make it work
-        val currArgName = ndarrayarg.argName match {
-          case "var" => "vari"
-          case "type" => "typeOf"
-          case default => ndarrayarg.argName
-        }
-        if (ndarrayarg.isOptional) {
-          argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
-        }
-        else {
-          argDef += s"${currArgName} : ${ndarrayarg.argType}"
-        }
-        // NDArray arg implementation
-        val returnType = "org.apache.mxnet.NDArray"
+    val functionDefs = newNDArrayFunctions.map(f => buildTypedFunction(c)(f))
+
+    structGeneration(c)(functionDefs, annottees: _*)
+  }
+
+  /**
+    * Implementation for Dynamic typed API NDArray.random.<functionName>
+    */
+  private def typedRandomAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
+    val functionDefs = rndFunctions.map(f => buildRandomTypedFunction(c)(f))
+    structGeneration(c)(functionDefs, annottees: _*)
+  }
+
+  private def buildTypedFunction(c: blackbox.Context)
+                                (function: NDArrayFunction): c.universe.DefDef = {
+    import c.universe._
+
+    val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+    val arrayType = "org.apache.mxnet.NDArray"
 
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= buildArgDefs(function)
+    argDef += "out : Option[NDArray] = None"
+
+    // Construct Implementation field
+    var impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
+
+    // NDArray arg implementation
+    impl ++=
+      function.listOfArgs.map { ndarrayarg =>
         // TODO: Currently we do not add place holder for NDArray
         // Example: an NDArray operator like the following format
         // nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3: NDArray(Optional)
         // If we place nd.foo(arg1, arg3 = arg3), do we need to add place holder for arg2?
         // What it should be?
         val base =
-          if (ndarrayarg.argType.equals(returnType)) {
-            s"args += $currArgName"
-          } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){
-            s"args ++= $currArgName"
+          if (ndarrayarg.argType.equals(arrayType)) {
+            s"args += ${ndarrayarg.safeArgName}"
+          } else if (ndarrayarg.argType.equals(s"Array[$arrayType]")) {
+            s"args ++= ${ndarrayarg.safeArgName}"
           } else {
-            "map(\"" + ndarrayarg.argName + "\") = " + currArgName
+            s"""map("${ndarrayarg.argName}") = ${ndarrayarg.safeArgName}"""
           }
-        impl.append(
-          if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
-          else base
-        )
-      })
-      // add default out parameter
-      argDef += "out : Option[NDArray] = None"
-      impl += "if (!out.isEmpty) map(\"out\") = out.get"
-      // scalastyle:off
-      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
-      // scalastyle:on
-      // Combine and build the function string
-      val returnType = "org.apache.mxnet.NDArrayFuncReturn"
-      var finalStr = s"def ${ndarrayfunction.name}"
-      finalStr += s" (${argDef.mkString(",")}) : $returnType"
-      finalStr += s" = {${impl.mkString("\n")}}"
-      c.parse(finalStr).asInstanceOf[DefDef]
-    }
-
-    structGeneration(c)(functionDefs, annottees : _*)
+        if (ndarrayarg.isOptional) s"if (!${ndarrayarg.safeArgName}.isEmpty) $base.get"
+        else base
+      }
+
+    impl += "if (!out.isEmpty) map(\"out\") = out.get"
+    impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(" +
+      s""""${function.name}", args.toSeq, map.toMap)"""
+
+    // Combine and build the function string
+    var finalStr = s"def ${function.name}"
+    finalStr += s" (${argDef.mkString(",")}) : $returnType"
+    finalStr += s" = {${impl.mkString("\n")}}"
+    c.parse(finalStr).asInstanceOf[DefDef]
   }
 
-  private def structGeneration(c: blackbox.Context)
-                              (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*)
-  : c.Expr[Any] = {
+  private def buildRandomTypedFunction(c: blackbox.Context)
+                                      (function: NDArrayFunction): c.universe.DefDef = {
     import c.universe._
-    val inputs = annottees.map(_.tree).toList
-    // pattern match on the inputs
-    val modDefs = inputs map {
-      case ClassDef(mods, name, something, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ClassDef(mods, name, something, q)
-      case ModuleDef(mods, name, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ModuleDef(mods, name, q)
-      case ex =>
-        throw new IllegalArgumentException(s"Invalid macro input: $ex")
-    }
-    // wrap the result up in an Expr, and return it
-    val result = c.Expr(Block(modDefs, Literal(Constant())))
-    result
-  }
-
-
-
 
-  // List and add all the atomic symbol functions to current module.
-  private def initNDArrayModule(): List[NDArrayFunction] = {
-    val opNames = ListBuffer.empty[String]
-    _LIB.mxListAllOpNames(opNames)
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeNDArrayFunction(opHandle.value, opName)
-    }).toList
-  }
-
-  // Create an atomic symbol function by handle and function name.
-  private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
-  : NDArrayFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs)
-    val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) {
-      s"This function support variable length of positional input (${keyVarNumArgs.value})."
+    val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+    val arrayType = "org.apache.mxnet.NDArray"
+
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= buildArgDefs(function)
+    argDef += "out : Option[NDArray] = None"
+
+    // Construct Implementation field
+    var impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
+
+    // determine what target to call
+    val arg = function.listOfArgs.filter(arg => arg.argType == "Any").head
+    if(arg.isOptional) {
+      impl +=
+        s"""val target = ${arg.safeArgName} match {
+           |   case Some(a:$arrayType) => "sample_${function.name}"
+           |   case None => "sample_${function.name}"
+           |   case _ => "random_${function.name}"
+           |}
+      """.stripMargin
     } else {
-      ""
-    }
-    val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"
-    val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
-    // scalastyle:off println
-    if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
-      && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
-      println("NDArray function definition:\n" + docStr)
+      impl +=
+        s"""val target = ${arg.safeArgName} match {
+           |   case _:$arrayType => "sample_${function.name}"
+           |   case _ => "random_${function.name}"
+           |}
+      """.stripMargin
     }
-    // scalastyle:on println
-    val argList = argNames zip argTypes map { case (argName, argType) =>
-      val typeAndOption =
-        CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.NDArray")
-      new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
-    }
-    new NDArrayFunction(aliasName, argList.toList)
+
+    // NDArray arg implementation
+    impl ++=
+      function.listOfArgs.map { ndarrayarg =>
+        // no Array[] in random/sample module, but let's keep that for a future case
+        val base =
+          if (ndarrayarg.argType.equals(arrayType)) {
+            s"args += ${ndarrayarg.safeArgName}"
+          } else if (ndarrayarg.argType.equals(s"Array[$arrayType]")) {
+            s"args ++= ${ndarrayarg.safeArgName}"
+          } else {
+            s"""map("${ndarrayarg.argName}") = ${ndarrayarg.safeArgName}"""
+          }
+        if (ndarrayarg.isOptional) s"if (!${ndarrayarg.safeArgName}.isEmpty) $base.get"
+        else base
+      }
+
+    impl += "if (!out.isEmpty) map(\"out\") = out.get"
+    impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(" +
+      s"target, args.toSeq, map.toMap)"
+
+    // Combine and build the function string
+    var finalStr = s"def ${function.name}"
+    finalStr += s" (${argDef.mkString(",")}) : $returnType"
+    finalStr += s" = {${impl.mkString("\n")}}"
+    c.parse(finalStr).asInstanceOf[DefDef]
   }
+
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index 42aa11781d8..4c1cffd2410 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -21,8 +21,6 @@ import scala.annotation.StaticAnnotation
 import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
 
 private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation {
   private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs
@@ -32,20 +30,31 @@ private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation
   private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typeSafeAPIDefs
 }
 
-private[mxnet] object SymbolImplMacros {
-  case class SymbolArg(argName: String, argType: String, isOptional : Boolean)
-  case class SymbolFunction(name: String, listOfArgs: List[SymbolArg])
+private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typedRandomAPIDefs
+}
+
+private[mxnet] object SymbolImplMacros extends GeneratorBase {
+  type SymbolArg = Arg
+  type SymbolFunction = Func
 
   // scalastyle:off havetype
   def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
     impl(c)(annottees: _*)
   }
+
   def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
     typedAPIImpl(c)(annottees: _*)
   }
+
+  def typedRandomAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+    typedRandomAPIImpl(c)(annottees: _*)
+  }
   // scalastyle:on havetype
 
-  private val symbolFunctions: List[SymbolFunction] = initSymbolModule()
+  private val symbolFunctions = buildFunctionList(true)
+
+  private val rndFunctions = buildRandomFunctionList(true)
 
   /**
     * Implementation for fixed input API structure
@@ -60,29 +69,28 @@ private[mxnet] object SymbolImplMacros {
     val newSymbolFunctions = {
       if (isContrib) symbolFunctions.filter(
         func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
-      else symbolFunctions.filter(!_.name.startsWith("_"))
+      else symbolFunctions.filterNot(_.name.startsWith("_"))
     }
 
-
     val functionDefs = newSymbolFunctions map { symbolfunction =>
-        val funcName = symbolfunction.name
-        val tName = TermName(funcName)
-        q"""
+      val funcName = symbolfunction.name
+      val tName = TermName(funcName)
+      q"""
             def $tName(name : String = null, attr : Map[String, String] = null)
             (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
              : org.apache.mxnet.Symbol = {
               createSymbolGeneral($funcName,name,attr,args,kwargs)
               }
          """.asInstanceOf[DefDef]
-      }
+    }
 
-    structGeneration(c)(functionDefs, annottees : _*)
+    structGeneration(c)(functionDefs, annottees: _*)
   }
 
   /**
     * Implementation for Dynamic typed API Symbol.api.<functioname>
     */
-  private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+  private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
 
     val isContrib: Boolean = c.prefix.tree match {
@@ -97,146 +105,129 @@ private[mxnet] object SymbolImplMacros {
     val newSymbolFunctions = {
       if (isContrib) symbolFunctions.filter(
         func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
-      else symbolFunctions.filter(!_.name.startsWith("_"))
+      else symbolFunctions.filterNot(_.name.startsWith("_"))
     }.filterNot(ele => notGenerated.contains(ele.name))
 
-    val functionDefs = newSymbolFunctions map { symbolfunction =>
+    val functionDefs = newSymbolFunctions.map(f => buildTypedFunction(c)(f))
+    structGeneration(c)(functionDefs, annottees: _*)
+  }
 
-      // Construct argument field
-      var argDef = ListBuffer[String]()
-      // Construct Implementation field
-      var impl = ListBuffer[String]()
-      impl += "val map = scala.collection.mutable.Map[String, Any]()"
-      impl += "var args = Seq[org.apache.mxnet.Symbol]()"
-      symbolfunction.listOfArgs.foreach({ symbolarg =>
-        // var is a special word used to define variable in Scala,
-        // need to changed to something else in order to make it work
-        val currArgName = symbolarg.argName match {
-          case "var" => "vari"
-          case "type" => "typeOf"
-          case default => symbolarg.argName
-        }
+  private def typedRandomAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
+    val functionDefs = rndFunctions.map(f => buildRandomTypedFunction(c)(f))
+    structGeneration(c)(functionDefs, annottees: _*)
+  }
+
+  private def buildTypedFunction(c: blackbox.Context)
+                                (function: SymbolFunction): c.universe.DefDef = {
+    import c.universe._
+
+    val returnType = "org.apache.mxnet.Symbol"
+    val symbolType = "org.apache.mxnet.Symbol"
+
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= buildArgDefs(function)
+    argDef += "name : String = null"
+    argDef += "attr : Map[String, String] = null"
+
+    // Construct Implementation field
+    var impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += "var args = Seq[org.apache.mxnet.Symbol]()"
+
+    // Symbol arg implementation
+    impl ++= function.listOfArgs.map { symbolarg =>
+      if (symbolarg.argType.equals(s"Array[$symbolType]")) {
+        if (symbolarg.isOptional)
+          s"if (!${symbolarg.safeArgName}.isEmpty) args = ${symbolarg.safeArgName}.get.toSeq"
+        else
+          s"args = ${symbolarg.safeArgName}.toSeq"
+      } else {
         if (symbolarg.isOptional) {
-          argDef += s"${currArgName} : Option[${symbolarg.argType}] = None"
+          s"if (!${symbolarg.safeArgName}.isEmpty) " +
+            s"""map("${symbolarg.argName}") = ${symbolarg.safeArgName}.get"""
         }
         else {
-          argDef += s"${currArgName} : ${symbolarg.argType}"
+          s"""map("${symbolarg.argName}") = ${symbolarg.safeArgName}"""
         }
-        // Symbol arg implementation
-        val returnType = "org.apache.mxnet.Symbol"
-        val base =
-        if (symbolarg.argType.equals(s"Array[$returnType]")) {
-          if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args = $currArgName.get.toSeq"
-          else s"args = $currArgName.toSeq"
-        } else {
-          if (symbolarg.isOptional) {
-            // scalastyle:off
-            s"if (!$currArgName.isEmpty) map(" + "\"" + symbolarg.argName + "\"" + s") = $currArgName.get"
-            // scalastyle:on
-          }
-          else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName"
-        }
-
-        impl += base
-      })
-      argDef += "name : String = null"
-      argDef += "attr : Map[String, String] = null"
-      // scalastyle:off
-      // TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated
-      impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, args, map.toMap)"
-      // scalastyle:on
-      // Combine and build the function string
-      val returnType = "org.apache.mxnet.Symbol"
-      var finalStr = s"def ${symbolfunction.name}"
-      finalStr += s" (${argDef.mkString(",")}) : $returnType"
-      finalStr += s" = {${impl.mkString("\n")}}"
-      c.parse(finalStr).asInstanceOf[DefDef]
+      }
     }
-    structGeneration(c)(functionDefs, annottees : _*)
+
+    impl += "org.apache.mxnet.Symbol.createSymbolGeneral(" +
+      s""""${function.name}", name, attr, args, map.toMap)"""
+
+    // Combine and build the function string
+    var finalStr = s"def ${function.name}"
+    finalStr += s" (${argDef.mkString(",")}) : $returnType"
+    finalStr += s" = {${impl.mkString("\n")}}"
+    c.parse(finalStr).asInstanceOf[DefDef]
   }
 
-  /**
-    * Generate class structure for all function APIs
-    * @param c
-    * @param funcDef DefDef type of function definitions
-    * @param annottees
-    * @return
-    */
-  private def structGeneration(c: blackbox.Context)
-                              (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*)
-  : c.Expr[Any] = {
+
+  private def buildRandomTypedFunction(c: blackbox.Context)
+                                      (function: SymbolFunction): c.universe.DefDef = {
+
     import c.universe._
-    val inputs = annottees.map(_.tree).toList
-    // pattern match on the inputs
-    val modDefs = inputs map {
-      case ClassDef(mods, name, something, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ClassDef(mods, name, something, q)
-      case ModuleDef(mods, name, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ModuleDef(mods, name, q)
-      case ex =>
-        throw new IllegalArgumentException(s"Invalid macro input: $ex")
-    }
-    // wrap the result up in an Expr, and return it
-    val result = c.Expr(Block(modDefs, Literal(Constant())))
-    result
-  }
 
-  // List and add all the atomic symbol functions to current module.
-  private def initSymbolModule(): List[SymbolFunction] = {
-    val opNames = ListBuffer.empty[String]
-    _LIB.mxListAllOpNames(opNames)
-    // TODO: Add '_linalg_', '_sparse_', '_image_' support
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeAtomicSymbolFunction(opHandle.value, opName)
-    }).toList
-  }
+    val returnType = "org.apache.mxnet.Symbol"
+    val symbolType = "org.apache.mxnet.Symbol"
+
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= buildArgDefs(function)
+    argDef += "name : String = null"
+    argDef += "attr : Map[String, String] = null"
+
+    // Construct Implementation field
+    var impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += "var args = Seq[org.apache.mxnet.Symbol]()"
+
+    // determine what target to call
+    val arg = function.listOfArgs.filter(arg => arg.argType == "Any").head
+    if(arg.isOptional) {
+      impl +=
+        s"""val target = ${arg.safeArgName} match {
+           |   case Some(s:$symbolType) => "sample_${function.name}"
+           |   case None => "sample_${function.name}"
+           |   case _ => "random_${function.name}"
+           |}
+      """.stripMargin
+    } else {
+      impl +=
+        s"""val target = ${arg.safeArgName} match {
+           |   case _:$symbolType => "sample_${function.name}"
+           |   case _ => "random_${function.name}"
+           |}
+      """.stripMargin
+    }
 
-  // Create an atomic symbol function by handle and function name.
-  private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String)
-      : SymbolFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs)
-    val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) {
-        s"This function support variable length of positional input (${keyVarNumArgs.value})."
+    // Symbol arg implementation
+    impl ++= function.listOfArgs.map { symbolarg =>
+      // no Array[] in random/sample module, but let's keep that for a future case
+      if (symbolarg.argType.equals(s"Array[$symbolType]")) {
+        if (symbolarg.isOptional)
+          s"if (!${symbolarg.safeArgName}.isEmpty) args = ${symbolarg.safeArgName}.get.toSeq"
+        else
+          s"args = ${symbolarg.safeArgName}.toSeq"
       } else {
-        ""
+        if (symbolarg.isOptional) {
+          s"if (!${symbolarg.safeArgName}.isEmpty) " +
+            s"""map("${symbolarg.argName}") = ${symbolarg.safeArgName}.get"""
+        }
+        else
+          s"""map("${symbolarg.argName}") = ${symbolarg.safeArgName}"""
       }
-    val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"
-    val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
-    // scalastyle:off println
-    if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
-          && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
-      println("Symbol function definition:\n" + docStr)
     }
-    // scalastyle:on println
-    val argList = argNames zip argTypes map { case (argName, argType) =>
-        val typeAndOption =
-          CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.Symbol")
-        new SymbolArg(argName, typeAndOption._1, typeAndOption._2)
-    }
-    new SymbolFunction(aliasName, argList.toList)
+
+    impl += "org.apache.mxnet.Symbol.createSymbolGeneral(" +
+      s"target, name, attr, args, map.toMap)"
+
+    // Combine and build the function string
+    var finalStr = s"def ${function.name}"
+    finalStr += s" (${argDef.mkString(",")}) : $returnType"
+    finalStr += s" = {${impl.mkString("\n")}}"
+    c.parse(finalStr).asInstanceOf[DefDef]
   }
+
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index d0ebe5b1d2c..c6344d7c7ca 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -22,12 +22,12 @@ private[mxnet] object CToScalaUtils {
 
   // Convert C++ Types to Scala Types
   def typeConversion(in : String, argType : String = "",
-                     argName : String, returnType : String) : String = {
+                     argName : String, familyType : String) : String = {
     in match {
       case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
-      case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
+      case "Symbol" | "NDArray" | "NDArray-or-Symbol" => familyType
       case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
-      => s"Array[$returnType]"
+      => s"Array[$familyType]"
       case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
       case "int" | "intorNone" | "int(non-negative)" => "Int"
       case "long" | "long(non-negative)" => "Long"
@@ -53,7 +53,7 @@ private[mxnet] object CToScalaUtils {
     * @return (Scala_Type, isOptional)
     */
   def argumentCleaner(argName: String,
-                      argType : String, returnType : String) : (String, Boolean) = {
+                      argType : String, familyType : String) : (String, Boolean) = {
     val spaceRemoved = argType.replaceAll("\\s+", "")
     var commaRemoved : Array[String] = new Array[String](0)
     // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -71,9 +71,9 @@ private[mxnet] object CToScalaUtils {
         s"""expected "optional" got ${commaRemoved(1)}""")
       require(commaRemoved(2).startsWith("default="),
         s"""expected "default=..." got ${commaRemoved(2)}""")
-      (typeConversion(commaRemoved(0), argType, argName, returnType), true)
+      (typeConversion(commaRemoved(0), argType, argName, familyType), true)
     } else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
-      val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
+      val tempType = typeConversion(commaRemoved(0), argType, argName, familyType)
       val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
       (tempType, tempOptional)
     } else {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services