You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/09 21:28:47 UTC
[incubator-mxnet] branch master updated: [MXNET-915] Java Inference
API core wrappers and tests (#12757)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new bcbac41 [MXNET-915] Java Inference API core wrappers and tests (#12757)
bcbac41 is described below
commit bcbac41bce5c183186482aa7ae3d2675c28e2737
Author: Andrew Ayres <an...@gmail.com>
AuthorDate: Tue Oct 9 14:28:33 2018 -0700
[MXNET-915] Java Inference API core wrappers and tests (#12757)
* Core Java API class commit
* Update ScalaStyle max line length to 132 instead of 100
---
scala-package/core/pom.xml | 14 +++
.../scala/org/apache/mxnet/javaapi/Context.scala | 47 ++++++++
.../scala/org/apache/mxnet/javaapi/DType.scala | 27 +++++
.../main/scala/org/apache/mxnet/javaapi/IO.scala | 34 ++++++
.../scala/org/apache/mxnet/javaapi/Shape.scala | 52 +++++++++
.../java/org/apache/mxnet/javaapi/ContextTest.java | 40 +++++++
.../java/org/apache/mxnet/javaapi/DTypeTest.java | 53 +++++++++
.../test/java/org/apache/mxnet/javaapi/IOTest.java | 35 ++++++
.../java/org/apache/mxnet/javaapi/ShapeTest.java | 121 +++++++++++++++++++++
9 files changed, 423 insertions(+)
diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 0ee7494..ea3a2d6 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -82,6 +82,14 @@
</configuration>
</plugin>
<plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <version>2.22.0</version>
+ <configuration>
+ <skipTests>false</skipTests>
+ </configuration>
+ </plugin>
+ <plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
@@ -105,6 +113,12 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>4.11</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.1</version>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
new file mode 100644
index 0000000..5f0caed
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.javaapi
+
+import collection.JavaConverters._
+
+/**
+ * Java-friendly wrapper around [[org.apache.mxnet.Context]], identifying the
+ * device (type and id) that computation is placed on.
+ *
+ * @param context the wrapped Scala context
+ */
+class Context(val context: org.apache.mxnet.Context) {
+
+  /** Numeric device type id, read from the wrapped context at construction. */
+  val deviceTypeid: Int = context.deviceTypeid
+
+  /**
+   * Create a context from a device type name and a device index.
+   *
+   * @param deviceTypeName device type name, e.g. "cpu"
+   * @param deviceId device index; the default of 0 is only usable from Scala,
+   *                 Java callers must pass it explicitly
+   */
+  def this(deviceTypeName: String, deviceId: Int = 0) =
+    this(new org.apache.mxnet.Context(deviceTypeName, deviceId))
+
+  /** Delegates to the wrapped context's `withScope`, evaluating `body`. */
+  def withScope[T](body: => T): T = context.withScope(body)
+
+  /** Device type name reported by the wrapped context. */
+  def deviceType: String = context.deviceType
+
+  override def toString: String = context.toString
+
+  override def equals(other: Any): Boolean = other match {
+    // Unwrap another Java-API wrapper first: passing the wrapper itself to
+    // context.equals would compare a Scala Context against this (different)
+    // wrapper class, so two wrappers of equal contexts never compared equal.
+    case that: Context => context.equals(that.context)
+    // Any other argument (including a raw org.apache.mxnet.Context or null)
+    // keeps the original delegation for backward compatibility.
+    case _ => context.equals(other)
+  }
+
+  // Consistent with equals: equal wrappers wrap equal contexts.
+  override def hashCode: Int = context.hashCode
+}
+
+
+/** Implicit conversions, shared contexts and device-name maps for the Java-API [[Context]]. */
+object Context {
+  /** Wrap a Scala context. */
+  implicit def fromContext(context: org.apache.mxnet.Context): Context = new Context(context)
+
+  /** Unwrap to the underlying Scala context. */
+  implicit def toContext(jContext: Context): org.apache.mxnet.Context = jContext.context
+
+  // Created once, when this object is first initialised.
+  val cpu: Context = org.apache.mxnet.Context.cpu()
+  val gpu: Context = org.apache.mxnet.Context.gpu()
+
+  // Java views of the Scala maps of the same names. devtype2str was previously
+  // built from devstr2type by mistake, making both fields the same map.
+  val devtype2str = org.apache.mxnet.Context.devtype2str.asJava
+  val devstr2type = org.apache.mxnet.Context.devstr2type.asJava
+
+  /** The current default context, as reported by the Scala API. */
+  def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala
new file mode 100644
index 0000000..e25cdde
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/DType.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.javaapi
+
+/**
+ * Java-visible aliases for the data types of [[org.apache.mxnet.DType]].
+ * Every member is the very same value as its Scala counterpart; from Java,
+ * access them as e.g. `DType.Float32()`.
+ */
+object DType extends Enumeration {
+  /** Alias of the Scala DType value type, so Java-API signatures can name it. */
+  type DType = org.apache.mxnet.DType.DType
+
+  // Floating-point types.
+  val Float16: DType = org.apache.mxnet.DType.Float16
+  val Float32: DType = org.apache.mxnet.DType.Float32
+  val Float64: DType = org.apache.mxnet.DType.Float64
+
+  // Integer types.
+  val UInt8: DType = org.apache.mxnet.DType.UInt8
+  val Int32: DType = org.apache.mxnet.DType.Int32
+
+  // Placeholder for a type that is not known.
+  val Unknown: DType = org.apache.mxnet.DType.Unknown
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
new file mode 100644
index 0000000..47b1c36
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+/**
+ * Java-friendly wrapper around [[org.apache.mxnet.DataDesc]], which describes
+ * a named input by its shape, element data type and layout.
+ *
+ * @param dataDesc the wrapped Scala data description
+ */
+class DataDesc(val dataDesc: org.apache.mxnet.DataDesc) {
+
+  /**
+   * Build a data description from its parts.
+   *
+   * @param name   name of the input
+   * @param shape  shape of the input (unwrapped to the Scala shape)
+   * @param dType  element data type
+   * @param layout layout string, e.g. "NCHW"
+   */
+  def this(name: String, shape: Shape, dType: DType.DType, layout: String) = {
+    this(new org.apache.mxnet.DataDesc(name, shape.shape, dType, layout))
+  }
+
+  override def toString(): String = {
+    dataDesc.toString()
+  }
+}
+
+/** Implicit conversions and static helpers for the Java-API [[DataDesc]]. */
+object DataDesc {
+  /** Wrap a Scala data description. */
+  implicit def fromDataDesc(dDesc: org.apache.mxnet.DataDesc): DataDesc = new DataDesc(dDesc)
+
+  /** Unwrap to the underlying Scala data description. */
+  implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc
+
+  /**
+   * Get the batch axis for `layout` by delegating to
+   * [[org.apache.mxnet.DataDesc.getBatchAxis]].
+   *
+   * @param layout layout string such as "NCHW"; a null from Java is passed
+   *               on as `None` (no layout) rather than as a present value
+   * @return the axis index computed by the Scala API
+   */
+  def getBatchAxis(layout: String): Int = {
+    // Option(...) instead of Some(...): Some(null) would tell the Scala API a
+    // layout is present while handing it a null string.
+    org.apache.mxnet.DataDesc.getBatchAxis(Option(layout))
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
new file mode 100644
index 0000000..594e3a6
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import collection.JavaConverters._
+
+/**
+ * Java-friendly wrapper around [[org.apache.mxnet.Shape]], the shape of an
+ * [[org.apache.mxnet.NDArray]] or other data.
+ *
+ * @param shape the wrapped Scala shape
+ */
+
+class Shape(val shape: org.apache.mxnet.Shape) {
+  /** Build a shape from a Java list of dimensions. */
+  def this(dims: java.util.List[java.lang.Integer]) =
+    this(new org.apache.mxnet.Shape(dims.asScala.map(Int.unbox)))
+
+  /** Build a shape from an array of dimensions. */
+  def this(dims: Array[Int]) = this(new org.apache.mxnet.Shape(dims))
+
+  /** Size of dimension `dim`. */
+  def apply(dim: Int): Int = shape.apply(dim)
+  /** Java-friendly alias of `apply`. */
+  def get(dim: Int): Int = apply(dim)
+  /** Number of dimensions. */
+  def size: Int = shape.size
+  /** Number of dimensions, delegating to the wrapped shape's `length`. */
+  def length: Int = shape.length
+  /** This shape without its first `dim` dimensions. */
+  def drop(dim: Int): Shape = shape.drop(dim)
+  /** Dimensions with indices in `[from, end)`. */
+  def slice(from: Int, end: Int): Shape = shape.slice(from, end)
+  /** Product of all dimensions. */
+  def product: Int = shape.product
+  /** First dimension. */
+  def head: Int = shape.head
+
+  def toArray: Array[Int] = shape.toArray
+  /** Dimensions as a Java list view. */
+  def toVector: java.util.List[Int] = shape.toVector.asJava
+
+  override def toString(): String = shape.toString
+
+  override def equals(o: Any): Boolean = o match {
+    // Unwrap another Java-API Shape: delegating the wrapper itself compared a
+    // Scala Shape with this wrapper class, so equal wrappers never matched.
+    case that: Shape => shape.equals(that.shape)
+    // Any other argument keeps the original delegation.
+    case _ => shape.equals(o)
+  }
+
+  // Consistent with equals: equal wrappers wrap equal shapes.
+  override def hashCode(): Int = shape.hashCode()
+}
+
+/** Implicit conversions between the Java-API [[Shape]] and [[org.apache.mxnet.Shape]]. */
+object Shape {
+  /** Unwrap a Java-API shape to the Scala shape it holds. */
+  implicit def toShape(jShape: Shape): org.apache.mxnet.Shape = {
+    jShape.shape
+  }
+
+  /** Wrap a Scala shape for use from Java. */
+  implicit def fromShape(shape: org.apache.mxnet.Shape): Shape = {
+    new Shape(shape)
+  }
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ContextTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ContextTest.java
new file mode 100644
index 0000000..abd4b5e
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ContextTest.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class ContextTest {
+
+ @Test
+ public void testCPU() {
+ Context.cpu();
+ }
+
+ @Test
+ public void testDefault() {
+ Context.defaultCtx();
+ }
+
+ @Test
+ public void testConstructor() {
+ new Context("cpu", 0);
+ }
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/DTypeTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/DTypeTest.java
new file mode 100644
index 0000000..2e356ed
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/DTypeTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+
+/**
+ * Tests that each Java-API {@link DType} value is the very same enumeration
+ * value as its Scala counterpart in {@code org.apache.mxnet.DType}.
+ */
+public class DTypeTest {
+
+    @Test
+    public void Float16Test() {
+        org.junit.Assert.assertSame(org.apache.mxnet.DType.Float16(), DType.Float16());
+    }
+
+    @Test
+    public void Float32Test() {
+        org.junit.Assert.assertSame(org.apache.mxnet.DType.Float32(), DType.Float32());
+    }
+
+    @Test
+    public void Float64Test() {
+        org.junit.Assert.assertSame(org.apache.mxnet.DType.Float64(), DType.Float64());
+    }
+
+    @Test
+    public void UnknownTest() {
+        org.junit.Assert.assertSame(org.apache.mxnet.DType.Unknown(), DType.Unknown());
+    }
+
+    @Test
+    public void Int32Test() {
+        org.junit.Assert.assertSame(org.apache.mxnet.DType.Int32(), DType.Int32());
+    }
+
+    @Test
+    public void UInt8Test() {
+        org.junit.Assert.assertSame(org.apache.mxnet.DType.UInt8(), DType.UInt8());
+    }
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/IOTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/IOTest.java
new file mode 100644
index 0000000..f53b5c4
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/IOTest.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+
+/**
+ * Tests for the Java-API {@link DataDesc} wrapper.
+ */
+public class IOTest {
+
+    @Test
+    public void testConstructor() {
+        Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+        DataDesc desc = new DataDesc("data", inputShape, DType.Float32(), "NCHW");
+        // The wrapped Scala description must carry exactly the constructor arguments.
+        org.apache.mxnet.DataDesc wrapped = desc.dataDesc();
+        org.junit.Assert.assertEquals("data", wrapped.name());
+        org.junit.Assert.assertEquals(inputShape.shape(), wrapped.shape());
+        org.junit.Assert.assertSame(DType.Float32(), wrapped.dtype());
+        org.junit.Assert.assertEquals("NCHW", wrapped.layout());
+    }
+
+    @Test
+    public void testgetBatchAxis() {
+        // In an "NCHW" layout the batch ('N') dimension is the first axis.
+        org.junit.Assert.assertEquals(0, DataDesc.getBatchAxis("NCHW"));
+    }
+
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ShapeTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ShapeTest.java
new file mode 100644
index 0000000..8f045b5
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ShapeTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import org.junit.Test;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+/**
+ * Tests for the Java-API {@link Shape} wrapper. Every assertEquals call puts
+ * the expected value first, as JUnit's failure messages assume.
+ */
+public class ShapeTest {
+    @Test
+    public void testArrayConstructor()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertTrue(Arrays.equals(new int[] {3, 4, 5}, jS.toArray()));
+    }
+
+    @Test
+    public void testListConstructor()
+    {
+        ArrayList<Integer> arrList = new ArrayList<Integer>();
+        arrList.add(3);
+        arrList.add(4);
+        arrList.add(5);
+        Shape jS = new Shape(arrList);
+        assertTrue(Arrays.equals(new int[] {3, 4, 5}, jS.toArray()));
+    }
+
+    @Test
+    public void testApply()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(4, jS.apply(1));
+    }
+
+    @Test
+    public void testGet()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(4, jS.get(1));
+    }
+
+    @Test
+    public void testSize()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(3, jS.size());
+    }
+
+    @Test
+    public void testLength()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(3, jS.length());
+    }
+
+    @Test
+    public void testDrop()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        // assertEquals (unlike assertTrue(a.equals(b))) reports both lists on failure.
+        assertEquals(Arrays.asList(4, 5), jS.drop(1).toVector());
+    }
+
+    @Test
+    public void testSlice()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(Arrays.asList(4), jS.slice(1, 2).toVector());
+    }
+
+    @Test
+    public void testProduct()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(60, jS.product());
+    }
+
+    @Test
+    public void testHead()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(3, jS.head());
+    }
+
+    @Test
+    public void testToArray()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertTrue(Arrays.equals(new int[] {3, 4, 5}, jS.toArray()));
+    }
+
+    @Test
+    public void testToVector()
+    {
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(Arrays.asList(3, 4, 5), jS.toVector());
+    }
+
+    @Test
+    public void testToStringAndHashCodeDelegate()
+    {
+        // The wrapper must report the same string and hash as the Scala shape it wraps.
+        Shape jS = new Shape(new int[] {3, 4, 5});
+        assertEquals(jS.shape().toString(), jS.toString());
+        assertEquals(jS.shape().hashCode(), jS.hashCode());
+    }
+}