Posted to commits@mxnet.apache.org by sx...@apache.org on 2020/10/07 18:19:03 UTC
[incubator-mxnet] branch master updated: [BUGFIX] [Numpy] MXNet fp16 initialization bug #19118 (#19270)
This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c2d56dc [BUGFIX] [Numpy] MXNet fp16 initialization bug #19118 (#19270)
c2d56dc is described below
commit c2d56dcde0c4256c91635670d49d90448865ab25
Author: Anshu Trivedi <an...@gmail.com>
AuthorDate: Wed Oct 7 23:47:01 2020 +0530
[BUGFIX] [Numpy] MXNet fp16 initialization bug #19118 (#19270)
* fix: MXNet fp16 initialization bug #19118
* add fix for a couple more initializers, and test
Co-authored-by: Sheng Zha <zh...@amazon.com>
---
 python/mxnet/initializer.py         |  8 ++++----
 tests/python/unittest/test_smoke.py | 29 +++++++++++++++++++++++++++--
 2 files changed, 31 insertions(+), 6 deletions(-)
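
For context, a minimal sketch of the failure mode reported in #19118 (illustrative only, not part of the commit): casting a Gluon block to float16 and then applying a random initializer broke, because the numpy-mode samplers were invoked without a dtype and so drew float32 values that could not be written into the fp16 parameter arrays in place.

# Hypothetical reproduction sketch for #19118; assumes MXNet's numpy mode.
from mxnet import np, npx, gluon, initializer
npx.set_np()                                   # enable numpy semantics

net = gluon.nn.Dense(16, in_units=16)
net.cast('float16')                            # parameters become fp16
net.initialize(initializer.Uniform())          # failed before this commit
net(np.zeros((16, 16), dtype='float16'))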
diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py
index 28f8af1..60a1d3d 100644
--- a/python/mxnet/initializer.py
+++ b/python/mxnet/initializer.py
@@ -507,7 +507,7 @@ class Uniform(Initializer):
 
     def _init_weight(self, _, arr):
         uniform_fn = _mx_np.random.uniform if is_np_array() else random.uniform
-        uniform_fn(-self.scale, self.scale, arr.shape, out=arr)
+        uniform_fn(-self.scale, self.scale, arr.shape, dtype=arr.dtype, out=arr)
 
 @register
 class Normal(Initializer):
@@ -541,7 +541,7 @@ class Normal(Initializer):
 
     def _init_weight(self, _, arr):
         normal_fn = _mx_np.random.normal if is_np_array() else random.normal
-        normal_fn(0, self.sigma, arr.shape, out=arr)
+        normal_fn(0, self.sigma, arr.shape, dtype=arr.dtype, out=arr)
 
 @register
 class Orthogonal(Initializer):
@@ -641,10 +641,10 @@ class Xavier(Initializer):
         scale = np.sqrt(self.magnitude / factor)
         if self.rnd_type == "uniform":
             uniform_fn = _mx_np.random.uniform if is_np_array() else random.uniform
-            uniform_fn(-scale, scale, arr.shape, out=arr)
+            uniform_fn(-scale, scale, arr.shape, dtype=arr.dtype, out=arr)
         elif self.rnd_type == "gaussian":
             normal_fn = _mx_np.random.normal if is_np_array() else random.normal
-            normal_fn(0, scale, arr.shape, out=arr)
+            normal_fn(0, scale, arr.shape, dtype=arr.dtype, out=arr)
         else:
             raise ValueError("Unknown random type")
 
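
All four fixed call sites follow the same pattern: forward the target array's dtype to the sampler so values are drawn directly in the parameter's precision. A minimal sketch of the new call shape (illustrative, mirroring the '+' lines above; 0.07 is Uniform's default scale):

# Illustrative sketch of the fixed call pattern.
from mxnet import np as mx_np, npx
npx.set_np()

arr = mx_np.zeros((2, 2), dtype='float16')
# With dtype=arr.dtype the sampler draws fp16 values straight into arr
# instead of defaulting to float32 and breaking the in-place write.
mx_np.random.uniform(-0.07, 0.07, arr.shape, dtype=arr.dtype, out=arr)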
diff --git a/tests/python/unittest/test_smoke.py b/tests/python/unittest/test_smoke.py
index c14310c..04f8679 100644
--- a/tests/python/unittest/test_smoke.py
+++ b/tests/python/unittest/test_smoke.py
@@ -15,8 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from mxnet import np, npx, use_np, autograd
+from mxnet import np, npx, use_np, autograd, initializer, gluon
 from common import setup_module, teardown_module, with_environment
+import pytest
 
 @use_np
 @with_environment('MXNET_ENGINE_TYPE', 'NaiveEngine')
@@ -66,4 +67,28 @@ def test_18934_empty_leaky_relu():
     autograd.mark_variables([arr], [arr_grad])
     with autograd.record():
         res = npx.leaky_relu(arr)
-        res.backward()
\ No newline at end of file
+        res.backward()
+
+@use_np
+@pytest.mark.parametrize('initializer',[
+    'zeros', 'ones', initializer.Constant(3),
+    initializer.Uniform(),
+    initializer.Normal(),
+    initializer.Orthogonal(),
+    initializer.Orthogonal(rand_type='normal'),
+    initializer.Xavier(),
+    initializer.Xavier(rnd_type='gaussian'),
+    initializer.MSRAPrelu(),
+    initializer.MSRAPrelu(factor_type='in'),
+    initializer.MSRAPrelu(factor_type='out'),
+    initializer.LSTMBias(),
+])
+@pytest.mark.parametrize('dtype', [
+    'float32', 'float64'
+])
+def test_19118(initializer, dtype):
+    net = gluon.nn.Dense(16, in_units=16)
+    net.cast(dtype)
+    net.initialize(initializer)
+    net.hybridize()
+    net(np.zeros((16, 16), dtype=dtype))
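
The new regression test can be run on its own with pytest's keyword filter (standard pytest usage; assumes a local MXNet source tree):

pytest tests/python/unittest/test_smoke.py -k test_19118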