You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/09 18:42:25 UTC

[GitHub] lanking520 closed pull request #12536: [MXNET-913] Java API --- Scala NDArray Improvement

lanking520 closed pull request #12536: [MXNET-913] Java API --- Scala NDArray Improvement
URL: https://github.com/apache/incubator-mxnet/pull/12536
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 9b6a7dc6654..dc1273315b2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -248,6 +248,11 @@ object NDArray extends NDArrayBase {
 
   def ones(ctx: Context, shape: Int *): NDArray = ones(Shape(shape: _*), ctx)
 
+  // Java-friendly overloads: accept the shape as a plain int[] and delegate via Shape(shape)
+  def empty(shape: Array[Int]): NDArray = empty(Shape(shape))
+  def zeros(shape: Array[Int]): NDArray = zeros(Shape(shape))
+  def ones(shape: Array[Int]) : NDArray = ones(Shape(shape))
+
   /**
    * Create a new NDArray filled with given value, with specified shape.
    * @param shape shape of the NDArray.
@@ -567,6 +572,18 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArrayCollector.collect(this)
   }
 
+  /**
+    * Java-friendly constructor: allocates the array eagerly (delayAlloc = false) with
+    * dtype Base.MX_REAL_TYPE and the given shape on ctx, then copies arr in via set.
+    * @param arr   values to copy in (presumably arr.length == shape.product; confirm in set)
+    * @param shape shape of the new NDArray
+    * @param ctx   context on which the array is allocated
+    */
+  def this(arr : Array[Float], shape : Shape, ctx : Context) = {
+    this(NDArray.newAllocHandle(shape, ctx, delayAlloc = false, Base.MX_REAL_TYPE))
+    this.set(arr)
+  }
+
   // record arrays who construct this array instance
   // we use weak reference to prevent gc blocking
   private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
@@ -940,6 +957,44 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
+  /* Java compatibility functions: named aliases for the symbolic operators
+     (add/subtract/multiply/div/pow/mod -> + - * / ** %, greater/greaterEqual/
+     lesser/lesserEqual -> > >= < <=). A leading underscore marks the variant that
+     delegates to the compound assignment and so updates this array (_add -> +=).
+   */
+  def add(other : NDArray) : NDArray = this + other
+  def add(other : Float) : NDArray = this + other
+  def _add(other : NDArray) : NDArray = this += other
+  def _add(other : Float) : NDArray = this += other
+  def subtract(other : NDArray) : NDArray = this - other
+  def subtract(other : Float) : NDArray = this - other
+  def _subtract(other : NDArray) : NDArray = this -= other
+  def _subtract(other : Float) : NDArray = this -= other
+  def multiply(other : NDArray) : NDArray = this * other
+  def multiply(other : Float) : NDArray = this * other
+  def _multiply(other : NDArray) : NDArray = this *= other
+  def _multiply(other : Float) : NDArray = this *= other
+  def div(other : NDArray) : NDArray = this / other
+  def div(other : Float) : NDArray = this / other
+  def _div(other : NDArray) : NDArray = this /= other
+  def _div(other : Float) : NDArray = this /= other
+  def pow(other : NDArray) : NDArray = this ** other
+  def pow(other : Float) : NDArray = this ** other
+  def _pow(other : NDArray) : NDArray = this **= other
+  def _pow(other : Float) : NDArray = this **= other
+  def mod(other : NDArray) : NDArray = this % other
+  def mod(other : Float) : NDArray = this % other
+  def _mod(other : NDArray) : NDArray = this %= other
+  def _mod(other : Float) : NDArray = this %= other
+  def greater(other : NDArray) : NDArray = this > other
+  def greater(other : Float) : NDArray = this > other
+  def greaterEqual(other : NDArray) : NDArray = this >= other
+  def greaterEqual(other : Float) : NDArray = this >= other
+  def lesser(other : NDArray) : NDArray = this < other
+  def lesser(other : Float) : NDArray = this < other
+  def lesserEqual(other : NDArray) : NDArray = this <= other
+  def lesserEqual(other : Float) : NDArray = this <= other
+
   /**
    * Return a copied flat java array of current array (row-major).
    * @return  A copy of array content.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
index 68917621772..772ec8e262c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
@@ -29,6 +29,15 @@ class Shape(dims: Traversable[Int]) extends Serializable {
     this(dims.toVector)
   }
 
+  /**
+    * Java compatible constructor, so Java callers can pass an int[] directly.
+    * @param dims dimension sizes, one per axis; copied with toVector, so later
+    *             writes to the array do not affect this Shape
+    */
+  def this(dims: Array[Int]) = {
+    this(dims.toVector)
+  }
+
   def apply(dim: Int): Int = shape(dim)
   def get(dim: Int): Int = apply(dim)
   def size: Int = shape.size
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala
new file mode 100644
index 00000000000..4b478e5fc5d
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/api/java/ArgBuilder.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.api.java
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import collection.JavaConverters._
+
+/**
+  * Collects arguments from Java for Scala APIs that take `(args: Any*)`.
+  * Use either positional args or key-value args on one builder, never both.
+  */
+class ArgBuilder {
+  private val data = ListBuffer[Any]()
+  private val map = mutable.Map[String, Any]()
+
+  /** Rejects positional args once any key-value arg has been added. */
+  private def requireNoKeyArgs(): Unit = require(map.isEmpty,
+    "Map is not empty, please do either key-value or positional-arg but not both")
+
+  /** Appends one positional argument. */
+  def addArg(anyRef: AnyRef): ArgBuilder = {
+    requireNoKeyArgs()
+    data += anyRef
+    this
+  }
+
+  /** Adds (or overwrites) a key-value argument; fails if positional args exist. */
+  def addArg(key : String, value : AnyRef) : ArgBuilder = {
+    require(data.isEmpty,
+      "Data is not empty, please do either key-value or positional-arg but not both")
+    map(key) = value
+    this
+  }
+
+  /** Appends all elements of a Java list as positional args, in list order. */
+  def addBatchArgs(list : java.util.List[AnyRef]) : ArgBuilder = {
+    requireNoKeyArgs()
+    data ++= list.asScala // visits exactly size() elements (indices 0 until size)
+    this
+  }
+
+  /** Appends all elements of an array as positional args, in array order. */
+  def addBatchArgs(arr : Array[AnyRef]) : ArgBuilder = {
+    requireNoKeyArgs()
+    data ++= arr
+    this
+  }
+
+  /** Immutable snapshot of the key-value args. */
+  def buildMap() : Map[String, Any] = map.toMap
+  /** Immutable snapshot of the positional args; later adds do not change it. */
+  def buildSeq() : Seq[Any] = data.toList
+}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 5d88bb39e50..edbc98e5a11 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -38,6 +38,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     assert(ndones.toScalar === 1f)
   }
 
+  test("new NDArray") {
+    val ndarray = new NDArray(Array(1.0f, 2.0f), Shape(1, 2), Context.cpu())
+    assert(ndarray.shape == Shape(1, 2) && ndarray.toArray.sameElements(Array(1.0f, 2.0f)))
+  }
+
   test ("call toScalar on an ndarray which is not a scalar") {
     intercept[Exception] { NDArray.zeros(1, 1).toScalar }
   }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services