You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/08/24 16:00:23 UTC
[incubator-mxnet] branch master updated: Allow stop of arange to be
inferred from dims. (#12064)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7bfe427 Allow stop of arange to be inferred from dims. (#12064)
7bfe427 is described below
commit 7bfe42786f79c3214b367aa9ef756f9e3f0eb132
Author: Taliesin Beynon <ta...@gmail.com>
AuthorDate: Fri Aug 24 18:00:13 2018 +0200
Allow stop of arange to be inferred from dims. (#12064)
* Allow stop of arange to be inferred from dims.
Enabled via a flag.
* modify NDArray/Symbol to add infer_range param
* Add test for arange-with-inference.
* Add a comment to readme about JDK 8.
* Fix approx=.
Include a test of this fix as well.
---
contrib/clojure-package/README.md | 4 +++-
.../src/org/apache/clojure_mxnet/ndarray.clj | 2 +-
.../src/org/apache/clojure_mxnet/symbol.clj | 12 ++++++++++-
.../org/apache/clojure_mxnet/operator_test.clj | 11 ++++++++++
.../test/org/apache/clojure_mxnet/test_util.clj | 6 ++++--
.../test/org/apache/clojure_mxnet/util_test.clj | 8 +++++++
python/mxnet/ndarray/ndarray.py | 4 ++--
python/mxnet/symbol/symbol.py | 4 ++--
.../src/main/scala/org/apache/mxnet/NDArray.scala | 9 ++++----
.../src/main/scala/org/apache/mxnet/Symbol.scala | 25 +++++++++++++++++++---
src/operator/tensor/init_op.h | 10 ++++++++-
tests/python/unittest/test_operator.py | 8 +++++++
12 files changed, 85 insertions(+), 18 deletions(-)
diff --git a/contrib/clojure-package/README.md b/contrib/clojure-package/README.md
index 5e7356c..ea678cc 100644
--- a/contrib/clojure-package/README.md
+++ b/contrib/clojure-package/README.md
@@ -107,7 +107,9 @@ The jars from maven with the needed MXNet native binaries in it. On startup, the
### Build from MXNET Source
-Checkout the latest sha from the main package
+First, ensure you have JDK 8 on your system. Later versions may produce cryptic build errors mentioning `scala.reflect.internal.MissingRequirementError`.
+
+Check out the latest SHA from the main package:
`git clone --recursive https://github.com/apache/incubator-mxnet.git ~/mxnet`
`cd ~/mxnet`
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
index e37a8bc..7ca4ede 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
@@ -89,7 +89,7 @@
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
([start stop]
(arange start stop {})))
-
+
(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
index 42ae034..12135fb 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
@@ -135,10 +135,20 @@
([start stop {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
- (Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype))
+ (Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype))
([start stop]
(arange start stop {})))
+(defn arange-with-inference
+ "Behaves like arange operator, but infers the stop value from the output shape,
+ which must be known from the rest of the net."
+ ([start {:keys [step repeat dtype]
+ :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
+ :as opts}]
+ (Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
+ ([start]
+ (arange-with-inference start {})))
+
;;; manually defined because of a conflicting arity of 2 with the auto-gen
(defn min
([sym-name kwargs-map symbol-list kwargs-map-1]
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
index a71a312..1b4b2ea 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj
@@ -222,6 +222,17 @@
(is (= 0 (count (executor/grad-arrays exec))))
(is (approx= 1e-4 result (-> (executor/outputs exec) (first))))))
+(deftest test-arange-with-inference
+ (let [arange (sym/arange-with-inference 0)
+ data (sym/variable "data")
+ added (sym/+ arange data)
+ result (range 0 4)
+ data-tmp (ndarray/zeros [4])
+ exec (sym/bind added (context/default-context) {"data" data-tmp})]
+ (executor/forward exec)
+ (is (= 0 (count (executor/grad-arrays exec))))
+ (is (approx= 1e-4 result (-> (executor/outputs exec) (first))))))
+
(deftest test-scalar-pow
(let [data (sym/variable "data")
shape-vec [1 1]
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
index dcdbea6..ecd54ca 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj
@@ -22,6 +22,8 @@
(if (and (number? x) (number? y))
(let [diff (Math/abs (- x y))]
(< diff tolerance))
- (reduce (fn [x y] (and x y))
- (map #(approx= tolerance %1 %2) x y))))
+ (and
+ (= (count x) (count y))
+ (reduce (fn [x y] (and x y))
+ (map #(approx= tolerance %1 %2) x y)))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index 5551fab..de34808 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -21,6 +21,7 @@
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.symbol :as sym]
+ [org.apache.clojure-mxnet.test-util :as test-util]
[clojure.spec.alpha :as s])
(:import (org.apache.mxnet Shape NDArrayFuncReturn NDArray)
(scala.collection Map Set)
@@ -183,3 +184,10 @@
(deftest test-validate
(is (nil? (util/validate! string? "foo" "Not a string!")))
(is (thrown-with-msg? Exception #"Not a string!" (util/validate! ::x 1 "Not a string!"))))
+
+(deftest test-approx=
+ (let [data1 [1 1 1 1]
+ data2 [1 1 1 1 9 9 9 9]
+ data3 [1 1 1 2]]
+ (is (not (test-util/approx= 1e-9 data1 data2)))
+ (is (test-util/approx= 2 data1 data3))))
\ No newline at end of file
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 46b21a9..d6d619f 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination):
# pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name
-def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
+def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t):
"""Returns evenly spaced values within a given interval.
Values are generated within the half-open interval [`start`, `stop`). In other
@@ -2519,7 +2519,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
if ctx is None:
ctx = current_context()
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
- dtype=dtype, ctx=str(ctx))
+ infer_range=infer_range, dtype=dtype, ctx=str(ctx))
# pylint: enable= no-member, protected-access, too-many-arguments
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 5f6cbd6..da5533f 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs):
return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs)
# pylint: disable=redefined-outer-name
-def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
+def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
"""Returns evenly spaced values within a given interval.
Parameters
@@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
if dtype is None:
dtype = _numpy.float32
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
- name=name, dtype=dtype)
+ infer_range=infer_range, name=name, dtype=dtype)
def histogram(a, bins=10, range=None, **kwargs):
"""Compute the histogram of the input data.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 548c30b..8b5e1e0 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -407,11 +407,10 @@ object NDArray extends NDArrayBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
- def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
- repeat: Int = 1, ctx: Context = Context.defaultCtx,
- dType: DType = Base.MX_REAL_TYPE): NDArray = {
- val params = Map("start" -> start, "step" -> step,
- "repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString())
+ def arange(start: Float, stop: Option[Float], step: Float,
+ repeat: Int, ctx: Context, dType: DType): NDArray = {
+ val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
+ "infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 194d368..e3e1a32 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -955,9 +955,28 @@ object Symbol extends SymbolBase {
* @return Symbol The created Symbol.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
- repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
- val params = Map("start" -> start, "step" -> step,
- "repeat" -> repeat, "dtype" -> dType.toString())
+ repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
+ arange(start, stop, step, repeat, infer_range = false, name, dType)
+ }
+
+ /**
+ * Returns evenly spaced values within a given interval.
+ * The stop value can be inferred from the output shape,
+ * which must be known from the rest of the net.
+ * @param start Start of interval. The default start value is 0.
+ * @param stop End of interval.
+ * @param step Spacing between values. The default step size is 1.
+ * @param repeat Number of times to repeat each element. The default repeat count is 1.
+ * @param infer_range Infer the stop value from output shape
+ * @param name Name of the created symbol.
+ * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
+ * @return Symbol The created Symbol.
+ */
+ def arange(start: Float, stop: Option[Float], step: Float,
+ repeat: Int, infer_range: Boolean, name: String,
+ dType: DType): Symbol = {
+ val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
+ "infer_range" -> infer_range, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
}
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 4af3a40..304911a 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
dmlc::optional<double> stop;
double step;
int repeat;
+ bool infer_range;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(RangeParam) {
@@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
.set_default(1)
.describe("The repeating time of all elements."
" E.g repeat=3, the element a will be repeated three times --> a, a, a.");
+ DMLC_DECLARE_FIELD(infer_range)
+ .set_default(false)
+ .describe("Whether to infer the stop position from the start, step, repeat, and output tensor"
+ "size.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
@@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
RangeParam param;
param.Init(attrs->dict);
- if (!static_cast<bool>(param.stop)) {
+ if (!static_cast<bool>(param.infer_range) && !static_cast<bool>(param.stop)) {
param.stop = param.start;
param.start = 0;
}
@@ -471,6 +476,9 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
<< "Range does not support step=0, received " << param.step;
CHECK(param.repeat > 0)
<< "Range only supports repeat > 0, received " << param.repeat;
+ if (param.infer_range && !param.stop.has_value()) {
+ return false;
+ }
if (param.step > 0) {
CHECK(param.start < param.stop.value())
<< "Invalid range (start, stop, step) = "
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index fc6b814..fd60611 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3646,10 +3646,18 @@ def test_init():
nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype)
assert_almost_equal(np_out, nd_out.asnumpy())
+ def test_arange_inferstop():
+ s = mx.sym.arange(start=0, stop=None, infer_range=True)
+ s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5]))
+ exe = s.bind(ctx=mx.cpu(), args={})
+ exe.forward()
+ assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4]))
+
test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32)
test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32)
test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16)
test_arange()
+ test_arange_inferstop()
@with_seed()