You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ke...@apache.org on 2020/04/14 06:09:28 UTC
[incubator-tvm] branch master updated: [Frontend|MXNet] SwapAxis operator support (#5246)
This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new b7545eb [Frontend|MXNet] SwapAxis operator support (#5246)
b7545eb is described below
commit b7545eb5ca87507ea04ccbe96c1a02040bef26be
Author: Mahesh Ambule <15...@users.noreply.github.com>
AuthorDate: Tue Apr 14 11:39:21 2020 +0530
[Frontend|MXNet] SwapAxis operator support (#5246)
* MXNet swap axis
* MXNet swap axis
* swap axis review comment
* swap axis review comment
---
python/tvm/relay/frontend/mxnet.py | 12 ++++++++++++
tests/python/frontend/mxnet/test_forward.py | 13 +++++++++++++
2 files changed, 25 insertions(+)
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 5c8e726..e6aa5f1 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -127,6 +127,17 @@ def _mx_unravel_index(inputs, attrs):
     return _op.unravel_index(inputs[0], shape_expr)
+def _mx_swap_axis(inputs, attrs):
+    assert len(inputs) == 1
+    dim1 = attrs.get_int('dim1')
+    dim2 = attrs.get_int('dim2')
+    shape = _infer_type(inputs[0]).checked_type.shape
+    axes = list(range(len(shape)))
+    axes[dim1] = dim2
+    axes[dim2] = dim1
+    return _op.transpose(inputs[0], axes=axes)
+
+
 def _mx_zeros(inputs, attrs):
     assert len(inputs) == 0
     shape = attrs.get_int_tuple("shape")
@@ -1813,6 +1824,7 @@ _convert_map = {
"slice_axis" : _mx_slice_axis,
"SliceChannel" : _mx_split,
"split" : _mx_split,
+ "SwapAxis" : _mx_swap_axis,
"expand_dims" : _mx_expand_dims,
"Concat" : _mx_concat,
"concat" : _mx_concat,
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index eb308c5..4a9848e 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -983,6 +983,18 @@ def test_forward_unravel_index():
# verify([0, 1, 2, 5], [2, 2], dtype)
+def test_forward_swap_axis():
+    def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
+        data = mx.sym.var('data')
+        mx_sym = mx.sym.swapaxes(data, dim1, dim2)
+        verify_mxnet_frontend_impl(mx_sym, in_shape, out_shape)
+
+    _verify_swap_axis((4, 5), (5, 4), 0, 1)
+    _verify_swap_axis((2, 4, 4, 5), (2, 5, 4, 4), 1, 3)
+    # MXNet errors out when dim1 == dim2
+    # _verify_swap_axis((4, 5), (5, 4), 0, 0)
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -1040,3 +1052,4 @@ if __name__ == '__main__':
     test_forward_cond()
     test_forward_make_loss()
     test_forward_unravel_index()
+    test_forward_swap_axis()
\ No newline at end of file