You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/23 22:05:32 UTC

[incubator-tvm] branch master updated: [MXNET]DepthToSpace & SpaceToDepth Operator (#5408)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 6faacc6  [MXNET]DepthToSpace & SpaceToDepth Operator (#5408)
6faacc6 is described below

commit 6faacc6f9ed2f0ed59dda3216710e4df489210e4
Author: Samuel <si...@huawei.com>
AuthorDate: Fri Apr 24 03:35:25 2020 +0530

    [MXNET]DepthToSpace & SpaceToDepth Operator (#5408)
---
 python/tvm/relay/frontend/mxnet.py          | 16 ++++++++++++++
 tests/python/frontend/mxnet/test_forward.py | 34 +++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 4edf0b8..be2f110 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -1073,6 +1073,20 @@ def _mx_one_hot(inputs, attrs):
     return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)
 
 
+def _mx_depth_to_space(inputs, attrs):
+    """Map the MXNet ``depth_to_space`` op onto ``_op.nn.depth_to_space``."""
+    assert len(inputs) == 1
+    block = attrs.get_int("block_size")
+    return _op.nn.depth_to_space(*inputs, block_size=block)
+
+
+def _mx_space_to_depth(inputs, attrs):
+    """Map the MXNet ``space_to_depth`` op onto ``_op.nn.space_to_depth``."""
+    assert len(inputs) == 1
+    block = attrs.get_int("block_size")
+    return _op.nn.space_to_depth(*inputs, block_size=block)
+
+
 def _mx_contrib_fifo_buffer(inputs, attrs):
     new_attrs = {}
     new_attrs['axis'] = attrs.get_int('axis')
@@ -1854,6 +1868,8 @@ _convert_map = {
     "make_loss"     : _mx_make_loss,
     "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
     "one_hot"           : _mx_one_hot,
+    "depth_to_space"    : _mx_depth_to_space,
+    "space_to_depth"    : _mx_space_to_depth,
     # vision
     "_contrib_BilinearResize2D" : _mx_resize,
     "_contrib_MultiBoxPrior" : _mx_multibox_prior,
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 4a9848e..10edff9 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -995,6 +995,38 @@ def test_forward_swap_axis():
     # _verify_swap_axis((4, 5), (5, 4), 0, 0)
 
 
+def test_forward_depth_to_space():
+    """Compare the Relay-converted depth_to_space with MXNet's own result."""
+    def verify(shape, blocksize=2):
+        data = np.random.uniform(size=shape).astype("float32")
+        expected = mx.nd.depth_to_space(mx.nd.array(data), blocksize).asnumpy()
+        symbol = mx.sym.depth_to_space(mx.sym.var("x"), blocksize)
+        mod, _ = relay.frontend.from_mxnet(symbol, {"x": data.shape})
+        for target, ctx in ctx_list():
+            for kind in ("graph", "debug"):
+                executor = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                actual = executor.evaluate()(data).asnumpy()
+                tvm.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-5)
+
+    verify((1, 18, 3, 3), 3)
+
+
+def test_forward_space_to_depth():
+    """Compare the Relay-converted space_to_depth with MXNet's own result."""
+    def verify(shape, blocksize=2):
+        data = np.random.uniform(size=shape).astype("float32")
+        expected = mx.nd.space_to_depth(mx.nd.array(data), blocksize).asnumpy()
+        symbol = mx.sym.space_to_depth(mx.sym.var("x"), blocksize)
+        mod, _ = relay.frontend.from_mxnet(symbol, {"x": data.shape})
+        for target, ctx in ctx_list():
+            for kind in ("graph", "debug"):
+                executor = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                actual = executor.evaluate()(data).asnumpy()
+                tvm.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-5)
+
+    verify((1, 1, 9, 9), 3)
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -1047,6 +1079,8 @@ if __name__ == '__main__':
     test_forward_instance_norm()
     test_forward_layer_norm()
     test_forward_one_hot()
+    test_forward_depth_to_space()
+    test_forward_space_to_depth()
     test_forward_convolution()
     test_forward_deconvolution()
     test_forward_cond()