You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/15 20:59:13 UTC

[incubator-mxnet] branch java-api updated: [MXNET-984] Add Java NDArray and introduce Java Operator Builder class (#12816)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch java-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/java-api by this push:
     new 64b9ac7  [MXNET-984] Add Java NDArray and introduce Java Operator Builder class (#12816)
64b9ac7 is described below

commit 64b9ac776bc79344a6bf8f6baf30dc42ea494ca2
Author: Lanking <la...@live.com>
AuthorDate: Mon Oct 15 13:58:57 2018 -0700

    [MXNET-984] Add Java NDArray and introduce Java Operator Builder class (#12816)
    
    * clean history and add commit
    
    * add lint header
    
    * bypass the java unittest when make the package
    
    * clean up redundant test
    
    * clean spacing issue
    
    * revert the change
    
    * clean up
    
    * cleanup the JMacros
    
    * adding line escape
    
    * revert some changes and fix scala style
    
    * fixes regarding to Naveen's comment
---
 Makefile                                           |   2 +-
 scala-package/core/pom.xml                         |   5 +-
 .../scala/org/apache/mxnet/javaapi/Context.scala   |   1 -
 .../scala/org/apache/mxnet/javaapi/NDArray.scala   | 202 ++++++++++++++++++++
 .../java/org/apache/mxnet/javaapi/NDArrayTest.java |  85 +++++++++
 .../apache/mxnet/javaapi/JavaNDArrayMacro.scala    | 203 +++++++++++++++++++++
 .../org/apache/mxnet/utils/CToScalaUtils.scala     |  22 +--
 .../test/scala/org/apache/mxnet/MacrosSuite.scala  |   2 +-
 8 files changed, 507 insertions(+), 15 deletions(-)

diff --git a/Makefile b/Makefile
index a4b41b8..fe2df2c 100644
--- a/Makefile
+++ b/Makefile
@@ -606,7 +606,7 @@ scalaclean:
 
 scalapkg:
 	(cd $(ROOTDIR)/scala-package; \
-		mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE) -Dcxx="$(CXX)" \
+		mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE),integrationtest -Dcxx="$(CXX)" \
 		    -Dbuild.platform="$(SCALA_PKG_PROFILE)" \
 			-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
 			-Dcurrent_libdir="$(ROOTDIR)/lib" \
diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index ea3a2d6..6e2d8d6 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -86,7 +86,10 @@
         <artifactId>maven-surefire-plugin</artifactId>
         <version>2.22.0</version>
         <configuration>
-          <skipTests>false</skipTests>
+          <argLine>
+            -Djava.library.path=${project.parent.basedir}/native/${platform}/target
+          </argLine>
+          <skipTests>${skipTests}</skipTests>
         </configuration>
       </plugin>
       <plugin>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
index 5f0caed..2f4f3e6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
@@ -42,6 +42,5 @@ object Context {
   val gpu: Context = org.apache.mxnet.Context.gpu()
   val devtype2str = org.apache.mxnet.Context.devstr2type.asJava
   val devstr2type = org.apache.mxnet.Context.devstr2type.asJava
-
   def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
new file mode 100644
index 0000000..c77b440
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.javaapi.DType.DType
+
+import collection.JavaConverters._
+
+@AddJNDArrayAPIs(false)
+object NDArray {
+  implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)
+
+  implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
+
+  def waitall(): Unit = org.apache.mxnet.NDArray.waitall()
+
+  def onehotEncode(indices: NDArray, out: NDArray): NDArray
+  = org.apache.mxnet.NDArray.onehotEncode(indices, out)
+
+  def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+  = org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
+  def empty(ctx: Context, shape: Array[Int]): NDArray
+  = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+  def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+  def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+  = org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
+  def zeros(ctx: Context, shape: Array[Int]): NDArray
+  = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+  def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+  def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+  = org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
+  def ones(ctx: Context, shape: Array[Int]): NDArray
+  = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+  def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+  def full(shape: Shape, value: Float, ctx: Context): NDArray
+  = org.apache.mxnet.NDArray.full(shape, value, ctx)
+
+  def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+
+  def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+
+  def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+
+  def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+  def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+
+  def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+  def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+
+  def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+  def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+
+  def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
+  = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+  def greaterEqual(lhs: NDArray, rhs: Float): NDArray
+  = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+
+  def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+  def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+
+  def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
+  = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+  def lesserEqual(lhs: NDArray, rhs: Float): NDArray
+  = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+
+  def array(sourceArr: java.util.List[java.lang.Float], shape: Shape, ctx: Context = null): NDArray
+  = org.apache.mxnet.NDArray.array(
+    sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
+
+  def arange(start: Float, stop: Float, step: Float, repeat: Int,
+             ctx: Context, dType: DType.DType): NDArray =
+    org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
+}
+
+class NDArray(val nd : org.apache.mxnet.NDArray ) {
+
+  def this(arr : Array[Float], shape : Shape, ctx : Context) = {
+    this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+  }
+
+  def this(arr : java.util.List[java.lang.Float], shape : Shape, ctx : Context) = {
+    this(NDArray.array(arr, shape, ctx))
+  }
+
+  def serialize() : Array[Byte] = nd.serialize()
+
+  def dispose() : Unit = nd.dispose()
+  def disposeDeps() : NDArray = nd.disposeDepsExcept()
+  // def disposeDepsExcept(arr : Array[NDArray]) : NDArray = nd.disposeDepsExcept()
+
+  def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)
+
+  def slice (i : Int) : NDArray = nd.slice(i)
+
+  def at(idx : Int) : NDArray = nd.at(idx)
+
+  def T : NDArray = nd.T
+
+  def dtype : DType = nd.dtype
+
+  def asType(dtype : DType) : NDArray = nd.asType(dtype)
+
+  def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)
+
+  def waitToRead(): Unit = nd.waitToRead()
+
+  def context : Context = nd.context
+
+  def set(value : Float) : NDArray = nd.set(value)
+  def set(other : NDArray) : NDArray = nd.set(other)
+  def set(other : Array[Float]) : NDArray = nd.set(other)
+
+  def add(other : NDArray) : NDArray = this.nd + other.nd
+  def add(other : Float) : NDArray = this.nd + other
+  def _add(other : NDArray) : NDArray = this.nd += other
+  def _add(other : Float) : NDArray = this.nd += other
+  def subtract(other : NDArray) : NDArray = this.nd - other
+  def subtract(other : Float) : NDArray = this.nd - other
+  def _subtract(other : NDArray) : NDArray = this.nd -= other
+  def _subtract(other : Float) : NDArray = this.nd -= other
+  def multiply(other : NDArray) : NDArray = this.nd * other
+  def multiply(other : Float) : NDArray = this.nd * other
+  def _multiply(other : NDArray) : NDArray = this.nd *= other
+  def _multiply(other : Float) : NDArray = this.nd *= other
+  def div(other : NDArray) : NDArray = this.nd / other
+  def div(other : Float) : NDArray = this.nd / other
+  def _div(other : NDArray) : NDArray = this.nd /= other
+  def _div(other : Float) : NDArray = this.nd /= other
+  def pow(other : NDArray) : NDArray = this.nd ** other
+  def pow(other : Float) : NDArray = this.nd ** other
+  def _pow(other : NDArray) : NDArray = this.nd **= other
+  def _pow(other : Float) : NDArray = this.nd **= other
+  def mod(other : NDArray) : NDArray = this.nd % other
+  def mod(other : Float) : NDArray = this.nd % other
+  def _mod(other : NDArray) : NDArray = this.nd %= other
+  def _mod(other : Float) : NDArray = this.nd %= other
+  def greater(other : NDArray) : NDArray = this.nd > other
+  def greater(other : Float) : NDArray = this.nd > other
+  def greaterEqual(other : NDArray) : NDArray = this.nd >= other
+  def greaterEqual(other : Float) : NDArray = this.nd >= other
+  def lesser(other : NDArray) : NDArray = this.nd < other
+  def lesser(other : Float) : NDArray = this.nd < other
+  def lesserEqual(other : NDArray) : NDArray = this.nd <= other
+  def lesserEqual(other : Float) : NDArray = this.nd <= other
+
+  def toArray : Array[Float] = nd.toArray
+
+  def toScalar : Float = nd.toScalar
+
+  def copyTo(other : NDArray) : NDArray = nd.copyTo(other)
+
+  def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)
+
+  def copy() : NDArray = copyTo(this.context)
+
+  def shape : Shape = nd.shape
+
+  def size : Int = shape.product
+
+  def asInContext(context: Context): NDArray = nd.asInContext(context)
+
+  override def equals(obj: Any): Boolean = nd.equals(obj)
+  override def hashCode(): Int = nd.hashCode
+}
+
+object NDArrayFuncReturn {
+  implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn)
+  : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn
+  implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn)
+  : NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn)
+}
+
+private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) {
+  def head : NDArray = ndFuncReturn.head
+  def get : NDArray = ndFuncReturn.get
+  def apply(i : Int) : NDArray = ndFuncReturn.apply(i)
+  // TODO: Add JavaNDArray operational stuff
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
new file mode 100644
index 0000000..a9bad83
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertTrue;
+
+public class NDArrayTest {
+    @Test
+    public void testCreateNDArray() {
+        NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f},
+                new Shape(new int[]{1, 3}),
+                new Context("cpu", 0));
+        int[] arr = new int[]{1, 3};
+        assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+        assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f);
+        List<Float> list = Arrays.asList(1.0f, 2.0f, 3.0f);
+        // Second way of creating an NDArray: from a java.util.List
+        nd = NDArray.array(list,
+                new Shape(new int[]{1, 3}),
+                new Context("cpu", 0));
+        assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+    }
+
+    @Test
+    public void testZeroOneEmpty(){
+        NDArray ones = NDArray.ones(new Context("cpu", 0), new int[]{100, 100});
+        NDArray zeros = NDArray.zeros(new Context("cpu", 0), new int[]{100, 100});
+        NDArray empty = NDArray.empty(new Context("cpu", 0), new int[]{100, 100});
+        int[] arr = new int[]{100, 100};
+        assertTrue(Arrays.equals(ones.shape().toArray(), arr));
+        assertTrue(Arrays.equals(zeros.shape().toArray(), arr));
+        assertTrue(Arrays.equals(empty.shape().toArray(), arr));
+    }
+
+    @Test
+    public void testComparison(){
+        NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
+        NDArray nd2 = new NDArray(new float[]{3.0f, 4.0f, 5.0f}, new Shape(new int[]{3}), new Context("cpu", 0));
+        nd = nd.add(nd2);
+        float[] greater = new float[]{1, 1, 1};
+        assertTrue(Arrays.equals(nd.greater(nd2).toArray(), greater));
+        nd = nd.subtract(nd2);
+        nd = nd.subtract(nd2);
+        float[] lesser = new float[]{0, 0, 0};
+        assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser));
+    }
+
+    @Test
+    public void testGenerated(){
+        NDArray$ NDArray = NDArray$.MODULE$;
+        float[] arr = new float[]{1.0f, 2.0f, 3.0f};
+        NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
+        float result = NDArray.norm(nd).invoke().get().toArray()[0];
+        float cal = 0.0f;
+        for (float ele : arr) {
+            cal += ele * ele;
+        }
+        cal = (float) Math.sqrt(cal);
+        assertTrue(Math.abs(result - cal) < 1e-5);
+        NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
+        NDArray.dot(nd, nd).setout(dotResult).invoke().get();
+        assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
+    }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
new file mode 100644
index 0000000..c530c73
--- /dev/null
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.init.Base._
+import org.apache.mxnet.utils.CToScalaUtils
+
+import scala.annotation.StaticAnnotation
+import scala.collection.mutable.ListBuffer
+import scala.language.experimental.macros
+import scala.reflect.macros.blackbox
+
+private[mxnet] class AddJNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) = macro JavaNDArrayMacro.typeSafeAPIDefs
+}
+
+private[mxnet] object JavaNDArrayMacro {
+  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
+  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
+
+  // scalastyle:off havetype
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+    typeSafeAPIImpl(c)(annottees: _*)
+  }
+  // scalastyle:on havetype
+
+  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
+
+  private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+    import c.universe._
+
+    val isContrib: Boolean = c.prefix.tree match {
+      case q"new AddJNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
+    }
+    // Defines operators that should not be generated
+    val notGenerated = Set("Custom")
+
+    val newNDArrayFunctions = {
+      if (isContrib) ndarrayFunctions.filter(
+        func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
+      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
+    }.filterNot(ele => notGenerated.contains(ele.name)).groupBy(_.name.toLowerCase).map(ele => {
+      // Pattern matching to avoid generating deprecated methods
+      if (ele._2.length == 1) ele._2.head
+      else {
+        if (ele._2.head.name.head.isLower) ele._2.head
+        else ele._2.last
+      }
+    })
+
+    val functionDefs = ListBuffer[DefDef]()
+    val classDefs = ListBuffer[ClassDef]()
+
+    newNDArrayFunctions.foreach { ndarrayfunction =>
+
+      // Construct argument field with all required args
+      var argDef = ListBuffer[String]()
+      // Construct Optional Arg
+      var OptionArgDef = ListBuffer[String]()
+      // Construct function Implementation field (e.g norm)
+      var impl = ListBuffer[String]()
+      impl += "val map = scala.collection.mutable.Map[String, Any]()"
+      // scalastyle:off
+      impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]"
+      // scalastyle:on
+      // Construct Class Implementation (e.g normBuilder)
+      var classImpl = ListBuffer[String]()
+      ndarrayfunction.listOfArgs.foreach({ ndarrayArg =>
+        // `var` and `type` are reserved words in Scala, so arguments with
+        // those names need to be renamed in order to make the generated code work
+        var currArgName = ndarrayArg.argName match {
+          case "var" => "vari"
+          case "type" => "typeOf"
+          case _ => ndarrayArg.argName
+        }
+        if (ndarrayArg.isOptional) {
+          OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = null"
+          val tempDef = s"def set$currArgName($currArgName : ${ndarrayArg.argType})"
+          val tempImpl = s"this.$currArgName = $currArgName\nthis"
+          classImpl += s"$tempDef = {$tempImpl}"
+        } else {
+          argDef += s"$currArgName : ${ndarrayArg.argType}"
+        }
+        // NDArray arg implementation
+        val returnType = "org.apache.mxnet.javaapi.NDArray"
+        val base =
+          if (ndarrayArg.argType.equals(returnType)) {
+            s"args += this.$currArgName"
+          } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){
+            s"this.$currArgName.foreach(args+=_)"
+          } else {
+            "map(\"" + ndarrayArg.argName + "\") = this." + currArgName
+          }
+        impl.append(
+          if (ndarrayArg.isOptional) s"if (this.$currArgName != null) $base"
+          else base
+        )
+      })
+      // add default out parameter
+      classImpl +=
+        "def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}"
+      impl += "if (this.out != null) map(\"out\") = this.out"
+      OptionArgDef += "private var out : org.apache.mxnet.NDArray = null"
+      val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn"
+      // scalastyle:off
+      // Combine and build the function string
+      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
+      val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})"
+      val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}"
+      val classFinal = s"$classDef {$classBody}"
+      val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})"
+      val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
+      val functionFinal = s"$functionDef = $functionBody"
+      // scalastyle:on
+      functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
+      classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
+    }
+
+    structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*)
+  }
+
+  private def structGeneration(c: blackbox.Context)
+                              (funcDef : List[c.universe.DefDef],
+                               classDef : List[c.universe.ClassDef],
+                               annottees: c.Expr[Any]*)
+  : c.Expr[Any] = {
+    import c.universe._
+    val inputs = annottees.map(_.tree).toList
+    // pattern match on the inputs
+    var modDefs = inputs map {
+      case ClassDef(mods, name, something, template) =>
+        val q = template match {
+          case Template(superMaybe, emptyValDef, defs) =>
+            Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
+          case ex =>
+            throw new IllegalArgumentException(s"Invalid template: $ex")
+        }
+        ClassDef(mods, name, something, q)
+      case ModuleDef(mods, name, template) =>
+        val q = template match {
+          case Template(superMaybe, emptyValDef, defs) =>
+            Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
+          case ex =>
+            throw new IllegalArgumentException(s"Invalid template: $ex")
+        }
+        ModuleDef(mods, name, q)
+      case ex =>
+        throw new IllegalArgumentException(s"Invalid macro input: $ex")
+    }
+    //    modDefs ++= classDef
+    // wrap the result up in an Expr, and return it
+    val result = c.Expr(Block(modDefs, Literal(Constant())))
+    result
+  }
+
+  // List and add all the atomic symbol functions to current module.
+  private def initNDArrayModule(): List[NDArrayFunction] = {
+    val opNames = ListBuffer.empty[String]
+    _LIB.mxListAllOpNames(opNames)
+    opNames.map(opName => {
+      val opHandle = new RefLong
+      _LIB.nnGetOpHandle(opName, opHandle)
+      makeNDArrayFunction(opHandle.value, opName)
+    }).toList
+  }
+
+  // Create an atomic symbol function by handle and function name.
+  private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
+  : NDArrayFunction = {
+    val name = new RefString
+    val desc = new RefString
+    val keyVarNumArgs = new RefString
+    val numArgs = new RefInt
+    val argNames = ListBuffer.empty[String]
+    val argTypes = ListBuffer.empty[String]
+    val argDescs = ListBuffer.empty[String]
+
+    _LIB.mxSymbolGetAtomicSymbolInfo(
+      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
+    val argList = argNames zip argTypes map { case (argName, argType) =>
+      val typeAndOption =
+        CToScalaUtils.argumentCleaner(argName, argType,
+          "org.apache.mxnet.javaapi.NDArray", "javaapi.Shape")
+      new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
+    }
+    new NDArrayFunction(aliasName, argList.toList)
+  }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index d0ebe5b..48d8fdf 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -21,19 +21,19 @@ private[mxnet] object CToScalaUtils {
 
 
   // Convert C++ Types to Scala Types
-  def typeConversion(in : String, argType : String = "",
-                     argName : String, returnType : String) : String = {
+  def typeConversion(in : String, argType : String = "", argName : String,
+                     returnType : String, shapeType : String = "Shape") : String = {
     in match {
-      case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
+      case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType"
       case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
       case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
       => s"Array[$returnType]"
-      case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
-      case "int" | "intorNone" | "int(non-negative)" => "Int"
-      case "long" | "long(non-negative)" => "Long"
-      case "double" | "doubleorNone" => "Double"
+      case "float" | "real_t" | "floatorNone" => "java.lang.Float"
+      case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
+      case "long" | "long(non-negative)" => "java.lang.Long"
+      case "double" | "doubleorNone" => "java.lang.Double"
       case "string" => "String"
-      case "boolean" | "booleanorNone" => "Boolean"
+      case "boolean" | "booleanorNone" => "java.lang.Boolean"
       case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => "Any"
       case default => throw new IllegalArgumentException(
         s"Invalid type for args: $default\nString argType: $argType\nargName: $argName")
@@ -52,8 +52,8 @@ private[mxnet] object CToScalaUtils {
     * @param argType Raw arguement Type description
     * @return (Scala_Type, isOptional)
     */
-  def argumentCleaner(argName: String,
-                      argType : String, returnType : String) : (String, Boolean) = {
+  def argumentCleaner(argName: String, argType : String,
+                      returnType : String, shapeType : String = "Shape") : (String, Boolean) = {
     val spaceRemoved = argType.replaceAll("\\s+", "")
     var commaRemoved : Array[String] = new Array[String](0)
     // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -73,7 +73,7 @@ private[mxnet] object CToScalaUtils {
         s"""expected "default=..." got ${commaRemoved(2)}""")
       (typeConversion(commaRemoved(0), argType, argName, returnType), true)
     } else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
-      val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
+      val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, shapeType)
       val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
       (tempType, tempOptional)
     } else {
diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
index c3a7c58..4404b08 100644
--- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -36,7 +36,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
     )
     val output = List(
       ("org.apache.mxnet.Symbol", true),
-      ("Int", false),
+      ("java.lang.Integer", false),
       ("org.apache.mxnet.Shape", true),
       ("String", true),
       ("Any", false)