You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) was not preserved in this copy.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/09 21:28:47 UTC

[incubator-mxnet] branch master updated: [MXNET-915] Java Inference API core wrappers and tests (#12757)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new bcbac41  [MXNET-915] Java Inference API core wrappers and tests (#12757)
bcbac41 is described below

commit bcbac41bce5c183186482aa7ae3d2675c28e2737
Author: Andrew Ayres <an...@gmail.com>
AuthorDate: Tue Oct 9 14:28:33 2018 -0700

    [MXNET-915] Java Inference API core wrappers and tests (#12757)
    
    * Core Java API class commit
    
    * Update ScalaStyle max line length to 132 instead of 100
---
 scala-package/core/pom.xml                         |  14 +++
 .../scala/org/apache/mxnet/javaapi/Context.scala   |  47 ++++++++
 .../scala/org/apache/mxnet/javaapi/DType.scala     |  27 +++++
 .../main/scala/org/apache/mxnet/javaapi/IO.scala   |  34 ++++++
 .../scala/org/apache/mxnet/javaapi/Shape.scala     |  52 +++++++++
 .../java/org/apache/mxnet/javaapi/ContextTest.java |  40 +++++++
 .../java/org/apache/mxnet/javaapi/DTypeTest.java   |  53 +++++++++
 .../test/java/org/apache/mxnet/javaapi/IOTest.java |  35 ++++++
 .../java/org/apache/mxnet/javaapi/ShapeTest.java   | 121 +++++++++++++++++++++
 9 files changed, 423 insertions(+)

diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 0ee7494..ea3a2d6 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -82,6 +82,14 @@
         </configuration>
       </plugin>
       <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-surefire-plugin</artifactId>
+        <version>2.22.0</version>
+        <configuration>
+          <skipTests>false</skipTests> <!-- run the JUnit Java API tests under src/test/java; NOTE(review): presumably overrides a skip set by the parent POM -->
+        </configuration>
+      </plugin>
+      <plugin>
         <groupId>org.scalastyle</groupId>
         <artifactId>scalastyle-maven-plugin</artifactId>
       </plugin>
@@ -105,6 +113,12 @@
       <scope>provided</scope>
     </dependency>
     <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>4.11</version>
+      <scope>test</scope> <!-- only needed to compile and run the JUnit tests in src/test/java -->
+    </dependency>
+    <dependency>
       <groupId>commons-io</groupId>
       <artifactId>commons-io</artifactId>
       <version>2.1</version>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
new file mode 100644
index 0000000..5f0caed
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.javaapi
+
+import collection.JavaConverters._
+
+/**
+  * Java-friendly wrapper around [[org.apache.mxnet.Context]]. Every member
+  * delegates to the wrapped Scala context.
+  *
+  * @param context the underlying Scala context
+  */
+class Context(val context: org.apache.mxnet.Context) {
+
+  /** Numeric device-type id of the wrapped context. */
+  val deviceTypeid: Int = context.deviceTypeid
+
+  /**
+    * Creates a context from a device-type name and a device index.
+    *
+    * Java callers must pass both arguments because Scala default arguments
+    * are not applied automatically from Java.
+    *
+    * @param deviceTypeName device-type name, e.g. "cpu"
+    * @param deviceId       device index (defaults to 0 for Scala callers)
+    */
+  def this(deviceTypeName: String, deviceId: Int = 0) =
+    this(new org.apache.mxnet.Context(deviceTypeName, deviceId))
+
+  /** Runs `body` through the wrapped context's `withScope` and returns its result. */
+  def withScope[T](body: => T): T = context.withScope(body)
+
+  /** Device-type name of the wrapped context. */
+  def deviceType: String = context.deviceType
+
+  override def toString: String = context.toString
+
+  /**
+    * Wrappers are equal when their wrapped Scala contexts are equal.
+    *
+    * `other` is unwrapped before comparing; passing the wrapper itself to the
+    * Scala context's `equals` would compare against a different class.
+    */
+  override def equals(other: Any): Boolean = other match {
+    case that: Context => context.equals(that.context)
+    case _ => false
+  }
+
+  /** Consistent with `equals`: delegates to the wrapped context. */
+  override def hashCode: Int = context.hashCode
+}
+
+
+/** Constants, factories and implicit conversions for [[Context]]. */
+object Context {
+  /** Wraps a Scala context. */
+  implicit def fromContext(context: org.apache.mxnet.Context): Context = new Context(context)
+  /** Unwraps to the Scala context so a wrapper can be passed to the Scala API. */
+  implicit def toContext(jContext: Context): org.apache.mxnet.Context = jContext.context
+
+  /** Context returned by `org.apache.mxnet.Context.cpu()` (default arguments). */
+  val cpu: Context = org.apache.mxnet.Context.cpu()
+  /** Context returned by `org.apache.mxnet.Context.gpu()` (default arguments). */
+  val gpu: Context = org.apache.mxnet.Context.gpu()
+  /** `org.apache.mxnet.Context.devtype2str`, exposed as a `java.util.Map`. */
+  val devtype2str = org.apache.mxnet.Context.devtype2str.asJava
+  /** `org.apache.mxnet.Context.devstr2type`, exposed as a `java.util.Map`. */
+  val devstr2type = org.apache.mxnet.Context.devstr2type.asJava
+
+  /** Wraps `org.apache.mxnet.Context.defaultCtx`. */
+  def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala
new file mode 100644
index 0000000..e25cdde
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.javaapi
+
+/**
+  * Java-accessible aliases for the element types defined in
+  * [[org.apache.mxnet.DType]]. Each member is the same value as its Scala
+  * counterpart, so it can be handed directly to the Scala API; Java code reads
+  * it through an accessor call such as `DType.Float32()`.
+  *
+  * NOTE(review): this object extends `Enumeration` but declares no `Value`s of
+  * its own, so the inherited `values`/`withName` do not know about the aliases
+  * below.
+  */
+object DType extends Enumeration {
+  /** Element-type alias, identical to `org.apache.mxnet.DType.DType`. */
+  type DType = org.apache.mxnet.DType.DType
+
+  // Floating-point types
+  val Float16: DType = org.apache.mxnet.DType.Float16
+  val Float32: DType = org.apache.mxnet.DType.Float32
+  val Float64: DType = org.apache.mxnet.DType.Float64
+
+  // Integer types
+  val UInt8: DType = org.apache.mxnet.DType.UInt8
+  val Int32: DType = org.apache.mxnet.DType.Int32
+
+  // Alias of `org.apache.mxnet.DType.Unknown`
+  val Unknown: DType = org.apache.mxnet.DType.Unknown
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
new file mode 100644
index 0000000..47b1c36
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+/**
+  * Java-friendly wrapper around [[org.apache.mxnet.DataDesc]], the description
+  * of a named piece of data (its shape, element type and layout).
+  *
+  * @param dataDesc the underlying Scala descriptor
+  */
+class DataDesc(val dataDesc: org.apache.mxnet.DataDesc) {
+
+  /**
+    * Creates a descriptor from its parts.
+    *
+    * @param name   name of the data, e.g. "data"
+    * @param shape  shape of the data
+    * @param dType  element type, e.g. `DType.Float32`
+    * @param layout layout string, e.g. "NCHW"
+    */
+  def this(name: String, shape: Shape, dType: DType.DType, layout: String) =
+    this(new org.apache.mxnet.DataDesc(name, shape, dType, layout))
+
+  override def toString(): String = dataDesc.toString()
+}
+
+/** Helpers and implicit conversions for [[DataDesc]]. */
+object DataDesc {
+  /** Wraps a Scala descriptor. */
+  implicit def fromDataDesc(dDesc: org.apache.mxnet.DataDesc): DataDesc = new DataDesc(dDesc)
+
+  /** Unwraps to the Scala descriptor so a wrapper can be passed to the Scala API. */
+  implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc
+
+  /**
+    * Returns the batch axis for `layout`, as computed by
+    * `org.apache.mxnet.DataDesc.getBatchAxis`.
+    *
+    * `Option(layout)` rather than `Some(layout)`: a Java `null` then reaches
+    * the Scala API as `None` ("no layout") instead of `Some(null)`.
+    * NOTE(review): confirm how the Scala API treats `None`.
+    *
+    * @param layout layout string such as "NCHW"; may be null
+    * @return the batch-axis index reported by the Scala API
+    */
+  def getBatchAxis(layout: String): Int =
+    org.apache.mxnet.DataDesc.getBatchAxis(Option(layout))
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
new file mode 100644
index 0000000..594e3a6
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import collection.JavaConverters._
+
+/**
+  * Java-friendly wrapper around [[org.apache.mxnet.Shape]], the dimensions of
+  * an array or other data. Members delegate to the wrapped Scala shape, and
+  * members that return a shape re-wrap the result.
+  *
+  * @param shape the underlying Scala shape
+  */
+class Shape(val shape: org.apache.mxnet.Shape) {
+  /** Creates a shape from a Java list of dimensions; elements must be non-null. */
+  def this(dims: java.util.List[java.lang.Integer]) =
+    this(new org.apache.mxnet.Shape(dims.asScala.map(Int.unbox)))
+
+  /** Creates a shape from an array of dimensions. */
+  def this(dims: Array[Int]) = this(new org.apache.mxnet.Shape(dims))
+
+  /** Size of dimension `dim`. */
+  def apply(dim: Int): Int = shape.apply(dim)
+  /** Same as [[apply]], named for Java callers. */
+  def get(dim: Int): Int = apply(dim)
+  /** Number of dimensions. */
+  def size: Int = shape.size
+  /** Number of dimensions (same as [[size]]). */
+  def length: Int = shape.length
+  /** This shape without its first `dim` dimensions. */
+  def drop(dim: Int): Shape = shape.drop(dim)
+  /** Dimensions at indices in the half-open range [from, end). */
+  def slice(from: Int, end: Int): Shape = shape.slice(from, end)
+  /** Product of all dimensions. */
+  def product: Int = shape.product
+  /** First dimension. */
+  def head: Int = shape.head
+
+  /** Dimensions as an `int[]`. */
+  def toArray: Array[Int] = shape.toArray
+
+  /**
+    * Dimensions as a Java list.
+    *
+    * Declared with `java.lang.Integer` elements because a Scala `Int` type
+    * argument is seen by Java as `List<Object>`.
+    */
+  def toVector: java.util.List[java.lang.Integer] = shape.toVector.map(Int.box).asJava
+
+  override def toString(): String = shape.toString
+
+  /**
+    * Wrappers are equal when their wrapped Scala shapes are equal. `o` is
+    * unwrapped first; passing the wrapper itself to the Scala shape's `equals`
+    * would compare against a different class.
+    */
+  override def equals(o: Any): Boolean = o match {
+    case that: Shape => shape.equals(that.shape)
+    case _ => false
+  }
+
+  /** Consistent with `equals`: delegates to the wrapped shape. */
+  override def hashCode(): Int = shape.hashCode()
+}
+
+/** Implicit conversions between [[Shape]] and `org.apache.mxnet.Shape`. */
+object Shape {
+  implicit def fromShape(shape: org.apache.mxnet.Shape): Shape = new Shape(shape)
+
+  implicit def toShape(jShape: Shape): org.apache.mxnet.Shape = jShape.shape
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ContextTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ContextTest.java
new file mode 100644
index 0000000..abd4b5e
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ContextTest.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class ContextTest {
+	
+	@Test
+	public void testCPU() {
+		Context.cpu();
+	}
+	
+	@Test
+	public void testDefault() {
+		Context.defaultCtx();
+	}
+	
+	@Test
+	public void testConstructor() {
+		new Context("cpu", 0);
+	}
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/DTypeTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/DTypeTest.java
new file mode 100644
index 0000000..2e356ed
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/DTypeTest.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import static org.junit.Assert.assertSame;
+
+import org.junit.Test;
+
+/** Checks that every {@link DType} alias is the same object as its Scala counterpart. */
+public class DTypeTest {
+
+  @Test
+  public void testFloat16() {
+    assertSame(org.apache.mxnet.DType.Float16(), DType.Float16());
+  }
+
+  @Test
+  public void testFloat32() {
+    assertSame(org.apache.mxnet.DType.Float32(), DType.Float32());
+  }
+
+  @Test
+  public void testFloat64() {
+    assertSame(org.apache.mxnet.DType.Float64(), DType.Float64());
+  }
+
+  @Test
+  public void testUnknown() {
+    assertSame(org.apache.mxnet.DType.Unknown(), DType.Unknown());
+  }
+
+  @Test
+  public void testInt32() {
+    assertSame(org.apache.mxnet.DType.Int32(), DType.Int32());
+  }
+
+  @Test
+  public void testUInt8() {
+    assertSame(org.apache.mxnet.DType.UInt8(), DType.UInt8());
+  }
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/IOTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/IOTest.java
new file mode 100644
index 0000000..f53b5c4
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/IOTest.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+import org.junit.Test;
+
+/** Tests for the Java-facing {@link DataDesc} wrapper. */
+public class IOTest {
+
+  @Test
+  public void testConstructor() {
+    Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+    DataDesc desc = new DataDesc("data", inputShape, DType.Float32(), "NCHW");
+    assertNotNull(desc.toString());
+  }
+
+  @Test
+  public void testGetBatchAxis() {
+    assertEquals(0, DataDesc.getBatchAxis("NCHW"));
+    assertEquals(1, DataDesc.getBatchAxis("TNC"));
+  }
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ShapeTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ShapeTest.java
new file mode 100644
index 0000000..8f045b5
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ShapeTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.junit.Test;
+
+/** Tests for the Java-facing {@link Shape} wrapper. */
+public class ShapeTest {
+  /** Shared fixture; JUnit 4 creates a new test-class instance per test method. */
+  private final Shape shape = new Shape(new int[] {3, 4, 5});
+
+  @Test
+  public void testArrayConstructor() {
+    assertArrayEquals(new int[] {3, 4, 5}, new Shape(new int[] {3, 4, 5}).toArray());
+  }
+
+  @Test
+  public void testListConstructor() {
+    List<Integer> dims = Arrays.asList(3, 4, 5);
+    assertArrayEquals(new int[] {3, 4, 5}, new Shape(dims).toArray());
+  }
+
+  @Test
+  public void testApply() {
+    assertEquals(4, shape.apply(1));
+  }
+
+  @Test
+  public void testGet() {
+    assertEquals(4, shape.get(1));
+  }
+
+  @Test
+  public void testSize() {
+    assertEquals(3, shape.size());
+  }
+
+  @Test
+  public void testLength() {
+    assertEquals(3, shape.length());
+  }
+
+  @Test
+  public void testDrop() {
+    assertEquals(Arrays.asList(4, 5), shape.drop(1).toVector());
+  }
+
+  @Test
+  public void testSlice() {
+    assertEquals(Collections.singletonList(4), shape.slice(1, 2).toVector());
+  }
+
+  @Test
+  public void testProduct() {
+    assertEquals(60, shape.product());
+  }
+
+  @Test
+  public void testHead() {
+    assertEquals(3, shape.head());
+  }
+
+  @Test
+  public void testToArray() {
+    assertArrayEquals(new int[] {3, 4, 5}, shape.toArray());
+  }
+
+  @Test
+  public void testToVector() {
+    assertEquals(Arrays.asList(3, 4, 5), shape.toVector());
+  }
+}