You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/03/06 07:15:56 UTC

[incubator-mxnet] branch master updated: Add default parameters for Scala NDArray.arange (#13816)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a0f3f92   Add default parameters for Scala NDArray.arange (#13816)
a0f3f92 is described below

commit a0f3f92f81f55b0bc9dcd74d54b8e76d98a17ea6
Author: Dang Trung Kien <ki...@pm.me>
AuthorDate: Wed Mar 6 15:15:32 2019 +0800

     Add default parameters for Scala NDArray.arange (#13816)
    
    * Add default arguments for arange
    
    * Remove redundant tag
    
    * Update test
    
    * Remove infer_range for python ndarray.arange
    
    * Update CONTRIBUTORS.md
    
    * Deprecate infer_range argument in ndarray.arange
---
 CONTRIBUTORS.md                                      |  1 +
 python/mxnet/ndarray/ndarray.py                      |  7 +++++--
 .../src/main/scala/org/apache/mxnet/NDArray.scala    |  8 +++-----
 .../test/scala/org/apache/mxnet/NDArraySuite.scala   | 20 +++++++++++++++-----
 4 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index caf61e8..0a7eb42 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -211,6 +211,7 @@ List of Contributors
 * [Harsh Patel](https://github.com/harshp8l)
 * [Xiao Wang](https://github.com/BeyonderXX)
 * [Piyush Ghai](https://github.com/piyushghai)
+* [Dang Trung Kien](https://github.com/kiendang)
 * [Zach Boldyga](https://github.com/zboldyga)
 * [Gordon Reid](https://github.com/gordon1992)
 
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index fb329f1..351c013 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2544,7 +2544,7 @@ def moveaxis(tensor, source, destination):
 
 
 # pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name
-def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t):
+def arange(start, stop=None, step=1.0, repeat=1, infer_range=None, ctx=None, dtype=mx_real_t):
     """Returns evenly spaced values within a given interval.
 
     Values are generated within the half-open interval [`start`, `stop`). In other
@@ -2588,10 +2588,13 @@ def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dt
     >>> mx.nd.arange(2, 6, step=2, repeat=3, dtype='int32').asnumpy()
     array([2, 2, 2, 4, 4, 4], dtype=int32)
     """
+    if infer_range is not None:
+        warnings.warn('`infer_range` argument has been deprecated',
+                      DeprecationWarning)
     if ctx is None:
         ctx = current_context()
     return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
-                             infer_range=infer_range, dtype=dtype, ctx=str(ctx))
+                             infer_range=False, dtype=dtype, ctx=str(ctx))
 # pylint: enable= no-member, protected-access, too-many-arguments
 
 
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index ca2e986..915e4c6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -575,15 +575,13 @@ object NDArray extends NDArrayBase {
    * @param stop End of interval.
    * @param step Spacing between values. The default step size is 1.
    * @param repeat Number of times to repeat each element. The default repeat count is 1.
-   * @param infer_range
-   *          When set to True, infer the stop position from the start, step,
-   *          repeat, and output tensor size.
    * @param ctx Device context. Default context is the current default context.
    * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
    * @return NDArray of evenly spaced values in the specified range.
    */
-  def arange(start: Float, stop: Option[Float], step: Float,
-             repeat: Int, ctx: Context, dType: DType): NDArray = {
+  def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
+             repeat: Int = 1, ctx: Context = Context.defaultCtx,
+             dType: DType = Base.MX_REAL_TYPE): NDArray = {
     val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
       "infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
     val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 054300e..72a5974 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -340,11 +340,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
       val stop = start + scala.util.Random.nextFloat() * 100
       val step = scala.util.Random.nextFloat() * 4
       val repeat = 1
-      val result = (start.toDouble until stop.toDouble by step.toDouble)
-              .flatMap(x => Array.fill[Float](repeat)(x.toFloat))
-      val range = NDArray.arange(start = start, stop = Some(stop), step = step,
-        repeat = repeat, ctx = Context.cpu(), dType = DType.Float32)
-      assert(CheckUtils.reldiff(result.toArray, range.toArray) <= 1e-4f)
+
+      val result1 = (start.toDouble until stop.toDouble by step.toDouble)
+        .flatMap(x => Array.fill[Float](repeat)(x.toFloat))
+      val range1 = NDArray.arange(start = start, stop = Some(stop), step = step,
+        repeat = repeat)
+      assert(CheckUtils.reldiff(result1.toArray, range1.toArray) <= 1e-4f)
+
+      val result2 = (0.0 until stop.toDouble by step.toDouble)
+        .flatMap(x => Array.fill[Float](repeat)(x.toFloat))
+      val range2 = NDArray.arange(stop, step = step, repeat = repeat)
+      assert(CheckUtils.reldiff(result2.toArray, range2.toArray) <= 1e-4f)
+
+      val result3 = 0f to stop by 1f
+      val range3 = NDArray.arange(stop)
+      assert(CheckUtils.reldiff(result3.toArray, range3.toArray) <= 1e-4f)
     }
   }