You are viewing a plain-text version of this content. The canonical version is available on the original mailing-list archive page (its hyperlink was lost in this plain-text copy).
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/08/24 16:00:23 UTC

[incubator-mxnet] branch master updated: Allow stop of arange to be inferred from dims. (#12064)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7bfe427  Allow stop of arange to be inferred from dims. (#12064)
7bfe427 is described below

commit 7bfe42786f79c3214b367aa9ef756f9e3f0eb132
Author: Taliesin Beynon <ta...@gmail.com>
AuthorDate: Fri Aug 24 18:00:13 2018 +0200

    Allow stop of arange to be inferred from dims. (#12064)
    
    * Allow stop of arange to be inferred from dims.
    
    Enabled via a flag.
    
    * modify NDArray/Symbol to add infer_range param
    
    * Add test for arange-with-inference.
    
    * Add a comment to readme about JDK 8.
    
    * Fix approx=.
    
    Include a test of this fix as well.
---
 contrib/clojure-package/README.md                  |  4 +++-
 .../src/org/apache/clojure_mxnet/ndarray.clj       |  2 +-
 .../src/org/apache/clojure_mxnet/symbol.clj        | 12 ++++++++++-
 .../org/apache/clojure_mxnet/operator_test.clj     | 11 ++++++++++
 .../test/org/apache/clojure_mxnet/test_util.clj    |  6 ++++--
 .../test/org/apache/clojure_mxnet/util_test.clj    |  8 +++++++
 python/mxnet/ndarray/ndarray.py                    |  4 ++--
 python/mxnet/symbol/symbol.py                      |  4 ++--
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |  9 ++++----
 .../src/main/scala/org/apache/mxnet/Symbol.scala   | 25 +++++++++++++++++++---
 src/operator/tensor/init_op.h                      | 10 ++++++++-
 tests/python/unittest/test_operator.py             |  8 +++++++
 12 files changed, 85 insertions(+), 18 deletions(-)

diff --git a/contrib/clojure-package/README.md b/contrib/clojure-package/README.md
index 5e7356c..ea678cc 100644
--- a/contrib/clojure-package/README.md
+++ b/contrib/clojure-package/README.md
@@ -107,7 +107,9 @@ The jars from maven with the needed MXNet native binaries in it. On startup, the
 
 ### Build from MXNET Source
 
-Checkout the latest sha from the main package
+First, ensure you have JDK 8 on your system. Later versions may produce cryptic build errors mentioning `scala.reflect.internal.MissingRequirementError`. 
+
+Checkout the latest SHA from the main package:
 
 `git clone --recursive https://github.com/apache/incubator-mxnet.git ~/mxnet`
 `cd ~/mxnet`
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
index e37a8bc..7ca4ede 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
@@ -89,7 +89,7 @@
    (NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
   ([start stop]
    (arange start stop {})))
-
+  
 (defn slice
   "Return a sliced NDArray that shares memory with current one."
   ([ndarray i]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
index 42ae034..12135fb 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
@@ -135,10 +135,20 @@
   ([start stop  {:keys [step repeat dtype]
                  :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
                  :as opts}]
-   (Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype))
+   (Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype))
   ([start stop]
    (arange start stop {})))
 
+(defn arange-with-inference
+  "Behaves like arange operator, but infers the stop value from the output shape, 
+   which must be known from the rest of the net."
+  ([start {:keys [step repeat dtype]
+           :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
+          :as opts}]
+   (Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
+  ([start]
+   (arange-with-inference start {})))
+
 ;;; manually defined because of a conflicting arity of 2 with the auto-gen
 (defn min
   ([sym-name kwargs-map symbol-list kwargs-map-1]
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
index a71a312..1b4b2ea 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
@@ -222,6 +222,17 @@
     (is (= 0 (count (executor/grad-arrays exec))))
     (is (approx= 1e-4 result (-> (executor/outputs exec) (first))))))
 
+(deftest test-arange-with-inference
+  (let [arange (sym/arange-with-inference 0)
+        data (sym/variable "data")
+        added (sym/+ arange data)
+        result (range 0 4)
+        data-tmp (ndarray/zeros [4])
+        exec (sym/bind added (context/default-context) {"data" data-tmp})]
+    (executor/forward exec)
+    (is (= 0 (count (executor/grad-arrays exec))))
+    (is (approx= 1e-4 result (-> (executor/outputs exec) (first))))))
+
 (deftest test-scalar-pow
   (let [data (sym/variable "data")
         shape-vec [1 1]
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
index dcdbea6..ecd54ca 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
@@ -22,6 +22,8 @@
   (if (and (number? x) (number? y))
     (let [diff (Math/abs (- x y))]
       (< diff tolerance))
-    (reduce (fn [x y] (and x y))
-            (map #(approx= tolerance %1 %2) x y))))
+    (and
+    	(= (count x) (count y))
+		(reduce (fn [x y] (and x y))
+            (map #(approx= tolerance %1 %2) x y)))))
 
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index 5551fab..de34808 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -21,6 +21,7 @@
             [org.apache.clojure-mxnet.util :as util]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
             [org.apache.clojure-mxnet.symbol :as sym]
+            [org.apache.clojure-mxnet.test-util :as test-util]
             [clojure.spec.alpha :as s])
   (:import (org.apache.mxnet Shape NDArrayFuncReturn NDArray)
            (scala.collection Map Set)
@@ -183,3 +184,10 @@
 (deftest test-validate
   (is (nil? (util/validate! string? "foo" "Not a string!")))
   (is (thrown-with-msg? Exception #"Not a string!" (util/validate! ::x 1 "Not a string!"))))
+
+(deftest test-approx=
+  (let [data1 [1 1 1 1]
+        data2 [1 1 1 1 9 9 9 9]
+        data3 [1 1 1 2]]
+    (is (not (test-util/approx= 1e-9 data1 data2)))
+    (is (test-util/approx= 2 data1 data3))))
\ No newline at end of file
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 46b21a9..d6d619f 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination):
 
 
 # pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name
-def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
+def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t):
     """Returns evenly spaced values within a given interval.
 
     Values are generated within the half-open interval [`start`, `stop`). In other
@@ -2519,7 +2519,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
     if ctx is None:
         ctx = current_context()
     return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
-                             dtype=dtype, ctx=str(ctx))
+                             infer_range=infer_range, dtype=dtype, ctx=str(ctx))
 # pylint: enable= no-member, protected-access, too-many-arguments
 
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 5f6cbd6..da5533f 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs):
     return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs)
 
 # pylint: disable=redefined-outer-name
-def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
+def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
     """Returns evenly spaced values within a given interval.
 
     Parameters
@@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
     if dtype is None:
         dtype = _numpy.float32
     return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
-                             name=name, dtype=dtype)
+                             infer_range=infer_range, name=name, dtype=dtype)
 
 def histogram(a, bins=10, range=None, **kwargs):
     """Compute the histogram of the input data.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 548c30b..8b5e1e0 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -407,11 +407,10 @@ object NDArray extends NDArrayBase {
    * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
    * @return NDArray of evenly spaced values in the specified range.
    */
-  def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
-    repeat: Int = 1, ctx: Context = Context.defaultCtx,
-    dType: DType = Base.MX_REAL_TYPE): NDArray = {
-    val params = Map("start" -> start, "step" -> step,
-      "repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString())
+  def arange(start: Float, stop: Option[Float], step: Float,
+             repeat: Int, ctx: Context, dType: DType): NDArray = {
+    val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
+      "infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
     val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
     NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
   }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 194d368..e3e1a32 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -955,9 +955,28 @@ object Symbol extends SymbolBase {
    * @return Symbol The created Symbol.
    */
   def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
-    repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
-    val params = Map("start" -> start, "step" -> step,
-      "repeat" -> repeat, "dtype" -> dType.toString())
+             repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
+    arange(start, stop, step, repeat, infer_range = false, name, dType)
+  }
+
+  /**
+   * Returns evenly spaced values within a given interval.
+   * stop value can be infered from the output shape,
+   * which must be known from the rest of the net.
+   * @param start Start of interval. The default start value is 0.
+   * @param stop End of interval.
+   * @param step Spacing between values. The default step size is 1.
+   * @param repeat Number of times to repeat each element. The default repeat count is 1.
+   * @param infer_range Infer the stop value from output shape
+   * @param ctx Device context. Default context is the current default context.
+   * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
+   * @return NDArray of evenly spaced values in the specified range.
+   */
+  def arange(start: Float, stop: Option[Float], step: Float,
+             repeat: Int, infer_range: Boolean, name: String,
+             dType: DType): Symbol = {
+    val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
+      "infer_range" -> infer_range, "dtype" -> dType.toString())
     val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
     createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
   }
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 4af3a40..304911a 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
   dmlc::optional<double> stop;
   double step;
   int repeat;
+  bool infer_range;
   std::string ctx;
   int dtype;
   DMLC_DECLARE_PARAMETER(RangeParam) {
@@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
     .set_default(1)
     .describe("The repeating time of all elements."
               " E.g repeat=3, the element a will be repeated three times --> a, a, a.");
+    DMLC_DECLARE_FIELD(infer_range)
+    .set_default(false)
+    .describe("Whether to infer the stop position from the start, step, repeat, and output tensor"
+              "size.");
     DMLC_DECLARE_FIELD(ctx)
     .set_default("")
     .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
@@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
 inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
   RangeParam param;
   param.Init(attrs->dict);
-  if (!static_cast<bool>(param.stop)) {
+  if (!static_cast<bool>(param.infer_range) && !static_cast<bool>(param.stop)) {
     param.stop = param.start;
     param.start = 0;
   }
@@ -471,6 +476,9 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
     << "Range does not support step=0, received " << param.step;
   CHECK(param.repeat > 0)
     << "Range only supports repeat > 0, received " << param.repeat;
+  if (param.infer_range && !param.stop.has_value()) {
+    return false;
+  }
   if (param.step > 0) {
     CHECK(param.start < param.stop.value())
       << "Invalid range (start, stop, step) = "
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index fc6b814..fd60611 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3646,10 +3646,18 @@ def test_init():
                 nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype)
                 assert_almost_equal(np_out, nd_out.asnumpy())
 
+    def test_arange_inferstop():
+        s = mx.sym.arange(start=0, stop=None, infer_range=True)
+        s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5]))
+        exe = s.bind(ctx=mx.cpu(), args={})
+        exe.forward()
+        assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4]))
+
     test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32)
     test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32)
     test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16)
     test_arange()
+    test_arange_inferstop()
 
 
 @with_seed()