You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/04/15 10:18:12 UTC
[incubator-tvm] branch master updated: [PYTORCH]Take,
Topk op support (#5332)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new b1364eb [PYTORCH]Take, Topk op support (#5332)
b1364eb is described below
commit b1364ebbedb6bf540d1d2610d772ac441e2f7cb5
Author: Samuel <si...@huawei.com>
AuthorDate: Wed Apr 15 15:48:03 2020 +0530
[PYTORCH]Take, Topk op support (#5332)
* [PYTORCH]take, topk op support
* Ci Failure fix
---
python/tvm/relay/frontend/pytorch.py | 35 ++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 57 +++++++++++++++++++++++++++
2 files changed, 92 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 38a811d..0acebe4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -272,6 +272,39 @@ def _select():
return _op.transform.take(data, index, axis=dim)
return _impl
+def _take():
+ # Converter for aten::take: gathers elements from the (flattened) input
+ # tensor at the given indices, mirroring torch.take semantics.
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ import torch
+
+ # Indices may arrive either as a graph input (Relay Var) or as a
+ # constant torch.Tensor baked into the traced graph.
+ if isinstance(inputs[1], _expr.Var):
+ indices = _op.cast(inputs[1], "int32")
+ elif isinstance(inputs[1], torch.Tensor):
+ indices = _wrap_const(inputs[1].numpy())
+ else:
+ msg = "Data type %s could not be parsed in take operator." % (type(inputs[1]))
+ raise AssertionError(msg)
+
+ # No axis passed: presumably Relay take then indexes the flattened
+ # input, matching torch.take — confirm against relay.op.transform.take.
+ return _op.transform.take(data, indices=indices)
+ return _impl
+
+def _topk():
+ # Converter for aten::topk. Torch operand order: (data, k, dim, largest, sorted).
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ k = int(inputs[1])
+ axis = int(inputs[2])
+ # torch's largest=True means descending order, so invert for is_ascend.
+ is_ascend = not bool(inputs[3])
+ sort = bool(inputs[4])
+
+ # Only sorted output is handled by this converter; reject sorted=False.
+ if not sort:
+ msg = "Currently supports only sorted output for topk operator."
+ raise AssertionError(msg)
+
+ # ret_type="both" yields (values, indices), matching torch.topk's pair.
+ outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both")
+
+ return outs[0], outs[1]
+ return _impl
+
def _reciprocal():
def _impl(inputs, input_types):
data = inputs[0]
@@ -1416,6 +1449,8 @@ def _get_convert_map(prelude):
"aten::split" : _split(),
"aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
+ "aten::take" : _take(),
+ "aten::topk" : _topk(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
"aten::prelu" : _prelu(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d9d280f..c562fce 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1545,6 +1545,61 @@ def test_forward_round():
verify_model(Round1().float().eval(), input_data=input_data)
+def test_forward_take():
+ # Exercises the aten::take converter: Take1 uses constant indices
+ # (captured in the trace), Take2 feeds indices as a second model input.
+ torch.set_grad_enabled(False)
+ class Take1(Module):
+ def forward(self, *args):
+ indices = torch.tensor([[0,0],[1,0]])
+ # Move constant indices to GPU so device matches args[0] when tracing on CUDA.
+ if torch.cuda.is_available():
+ indices = indices.cuda()
+ return torch.take(args[0], indices)
+
+ class Take2(Module):
+ def forward(self, *args):
+ return torch.take(args[0], args[1])
+
+ input_data = torch.tensor([[1,2],[3,4]])
+ verify_model(Take1().float().eval(), input_data=input_data)
+ indices = torch.tensor([[0,0],[1,0]])
+ verify_model(Take2().float().eval(), input_data=[input_data, indices])
+
+
+def test_forward_topk():
+ # Exercises the aten::topk converter over its keyword variants:
+ # default dim, explicit negative/positive dim, largest True/False,
+ # and explicit sorted=True (sorted=False is rejected by the converter).
+ torch.set_grad_enabled(False)
+ class Topk1(Module):
+ def forward(self, *args):
+ return torch.topk(args[0], k=3)
+
+ class Topk2(Module):
+ def forward(self, *args):
+ return torch.topk(args[0], k=3, dim=-2)
+
+ class Topk3(Module):
+ def forward(self, *args):
+ return torch.topk(args[0], k=3, dim=3)
+
+ class Topk4(Module):
+ def forward(self, *args):
+ return torch.topk(args[0], k=3, largest=True)
+
+ class Topk5(Module):
+ def forward(self, *args):
+ # largest=False maps to is_ascend=True in the converter.
+ return torch.topk(args[0], k=3, largest=False)
+
+ class Topk6(Module):
+ def forward(self, *args):
+ return torch.topk(args[0], k=3, sorted=True)
+
+ input_shape = [1, 3, 10, 10]
+ input_data = torch.rand(input_shape).float()
+ verify_model(Topk1().float().eval(), input_data=input_data)
+ verify_model(Topk2().float().eval(), input_data=input_data)
+ verify_model(Topk3().float().eval(), input_data=input_data)
+ verify_model(Topk4().float().eval(), input_data=input_data)
+ verify_model(Topk5().float().eval(), input_data=input_data)
+ verify_model(Topk6().float().eval(), input_data=input_data)
+
+
if __name__ == "__main__":
# Single operator tests
test_forward_add()
@@ -1587,6 +1642,8 @@ if __name__ == "__main__":
test_forward_size()
test_forward_view()
test_forward_select()
+ test_forward_take()
+ test_forward_topk()
test_forward_clone()
test_forward_softplus()
test_forward_softsign()