You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/04/11 05:03:06 UTC
[incubator-tvm] branch master updated: [PYTORCH]Abs, Arange,
Softplus ops (#5295)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 5b37d4c [PYTORCH]Abs, Arange, Softplus ops (#5295)
5b37d4c is described below
commit 5b37d4c15378e872c279ca5edbcb077d1a5fd20b
Author: Samuel <si...@huawei.com>
AuthorDate: Sat Apr 11 10:32:58 2020 +0530
[PYTORCH]Abs, Arange, Softplus ops (#5295)
* [PYTHON]Abs, Arange, Softplus ops
* Review comments updated
---
python/tvm/relay/frontend/pytorch.py | 52 +++++++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 66 +++++++++++++++++++++++++++
2 files changed, 118 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b8b32e7..a542ccc 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -57,6 +57,33 @@ def _elemwise(name):
return get_relay_op(name)(data0, data1)
return _impl
+def _abs():
+    """Return a converter that lowers aten::abs to relay's elementwise abs."""
+    def _impl(inputs, input_types):
+        # Only the tensor operand is needed; input_types is not consulted.
+        return _op.abs(inputs[0])
+    return _impl
+
+def _arange():
+    """Return a converter for aten::arange.
+
+    TorchScript records aten::arange with one of three argument layouts:
+
+    * 5 inputs: (end, dtype, layout, device, pin_memory)
+    * 6 inputs: (start, end, dtype, layout, device, pin_memory)
+    * 7 inputs: (start, end, step, dtype, layout, device, pin_memory)
+
+    Output dtype: an explicit ``dtype`` argument takes precedence (as in
+    torch.arange); otherwise a float bound/step selects "float", else the
+    default from ``_convert_dtype_value(None)`` is used.
+
+    Raises AssertionError for any other argument count.
+    """
+    def _impl(inputs, input_types):
+        num_inputs = len(inputs)
+        if num_inputs == 5:
+            start, stop, step = 0, inputs[0], 1
+            num_scalars = 1
+        elif num_inputs == 6:
+            start, stop, step = inputs[0], inputs[1], 1
+            num_scalars = 2
+        elif num_inputs == 7:
+            start, stop, step = inputs[0], inputs[1], inputs[2]
+            num_scalars = 3
+        else:
+            msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
+            raise AssertionError(msg)
+
+        # In every layout the dtype argument directly follows the scalars.
+        dtype_value = inputs[num_scalars]
+        if dtype_value is not None:
+            # Explicit dtype wins; "is not None" keeps enum value 0 (uint8).
+            dtype = _convert_dtype_value(dtype_value)
+        elif "float" in input_types[0:num_scalars]:
+            dtype = "float"
+        else:
+            dtype = _convert_dtype_value(None)
+
+        return _op.transform.arange(start=_create_typed_const(start, dtype),
+                                    stop=_create_typed_const(stop, dtype),
+                                    step=_create_typed_const(step, dtype),
+                                    dtype=_convert_data_type(dtype))
+    return _impl
+
def _squeeze():
def _impl(inputs, input_types):
data = inputs[0]
@@ -732,6 +759,13 @@ def _sigmoid():
return _op.tensor.sigmoid(data)
return _impl
+def _softplus():
+    """Return a converter for aten::softplus(input, beta, threshold).
+
+    Computes ``log(1 + exp(beta * x)) / beta``. Like torch.nn.Softplus, it
+    returns ``x`` unchanged wherever ``beta * x > threshold``. Without that
+    fallback, ``exp`` overflows to inf for large inputs.
+    """
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        beta = _expr.const(float(inputs[1]))
+        # torch.nn.Softplus documents threshold=20 as its default.
+        threshold = _expr.const(float(inputs[2]) if len(inputs) > 2 else 20.)
+        scaled = data * beta
+        soft = _op.log(_op.exp(scaled) + _expr.const(1.)) / beta
+        # Take the linear branch where the soft branch may overflow.
+        return _op.where(_op.greater(scaled, threshold), data, soft)
+    return _impl
+
def _avg_pool2d():
def _impl(inputs, input_types):
data = inputs[0]
@@ -1044,6 +1078,21 @@ def _Float():
return _impl
# Helper functions for operator implementation
+def _convert_dtype_value(val):
+    """Map a TorchScript ScalarType enum value to a torch dtype name.
+
+    ``None`` (dtype argument omitted) maps to ``"torch.int64"``. Callers
+    pass the result on to ``_convert_data_type``.
+
+    Raises NotImplementedError for values not in the table below.
+    """
+    convert_torch_dtype_map = {7:"torch.float64",
+                               6:"torch.float32",
+                               5:"torch.float16",
+                               4:"torch.int64",
+                               3:"torch.int32",
+                               2:"torch.int16",
+                               1:"torch.int8",
+                               0:"torch.uint8",
+                               None:"torch.int64"} # Default is torch.int64
+    # NOTE(review): confirm _convert_data_type accepts "torch.uint8".
+    if val in convert_torch_dtype_map:
+        return convert_torch_dtype_map[val]
+    # %s instead of %d: an unexpected val need not be an int, and %d would
+    # raise TypeError instead of the intended NotImplementedError.
+    msg = "Torch data type value %s is not handled yet." % (val,)
+    raise NotImplementedError(msg)
def _convert_data_type(input_type):
if input_type in ["double", "torch.float64"]:
@@ -1118,6 +1167,8 @@ _convert_map = {
"aten::pow" : _elemwise("power"),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
+ "aten::abs" : _abs(),
+ "aten::arange" : _arange(),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::reciprocal" : _reciprocal(),
@@ -1167,6 +1218,7 @@ _convert_map = {
"aten::clone" : _clone(),
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
+ "aten::softplus" : _softplus(),
"aten::avg_pool2d" : _avg_pool2d(),
"aten::avg_pool3d" : _avg_pool3d(),
"aten::dropout" : _dropout(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 4226463..d60ab9e 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -375,6 +375,54 @@ def test_forward_squeeze():
verify_model(Squeeze1().float().eval(), input_data=input_data)
verify_model(Squeeze2().float().eval(), input_data=input_data)
+def test_forward_arange():
+    """Check aten::arange conversion across its call forms and dtypes."""
+    torch.set_grad_enabled(False)
+
+    class Arange1(Module):  # end only, int
+        def forward(self, *args):
+            return torch.arange(5)
+
+    class Arange2(Module):  # end only, float
+        def forward(self, *args):
+            return torch.arange(2.5)
+
+    class Arange3(Module):  # start, end
+        def forward(self, *args):
+            return torch.arange(1, 4)
+
+    class Arange4(Module):  # start, end, float step
+        def forward(self, *args):
+            return torch.arange(1, 2.5, 0.5)
+
+    class Arange5(Module):  # start, end, step, explicit int32
+        def forward(self, *args):
+            return torch.arange(1, 2, 1, dtype=torch.int32)
+
+    class Arange6(Module):  # keyword arguments
+        def forward(self, *args):
+            return torch.arange(start=1, end=6, step=2)
+
+    class Arange7(Module):  # start, end, explicit float32
+        def forward(self, *args):
+            return torch.arange(1, 4, dtype=torch.float32)
+
+    class Arange8(Module):  # start, end, step, explicit int16
+        def forward(self, *args):
+            return torch.arange(1, 2, 1, dtype=torch.int16)
+
+    for arange_module in (Arange1, Arange2, Arange3, Arange4,
+                          Arange5, Arange6, Arange7, Arange8):
+        verify_model(arange_module().float().eval())
+
+def test_forward_abs():
+    """Check aten::abs conversion on inputs of both signs."""
+    torch.set_grad_enabled(False)
+    input_shape = [2, 1, 10, 1, 10]
+
+    class Abs1(Module):
+        def forward(self, *args):
+            return args[0].abs()
+
+    # torch.rand lies in [0, 1), where abs is the identity; rescale to
+    # [-1, 1) so negative values are actually exercised.
+    input_data = torch.rand(input_shape).float() * 2 - 1
+    verify_model(Abs1().float().eval(), input_data=input_data)
+
def test_forward_concatenate():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
@@ -445,6 +493,20 @@ def test_forward_selu():
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.SELU().eval(), input_data=input_data)
+def test_forward_softplus():
+    """Check torch.nn.Softplus with default and non-default beta/threshold."""
+    torch.set_grad_enabled(False)
+    input_data = torch.rand([1, 3, 10, 10]).float()
+    softplus_modules = [torch.nn.Softplus(),
+                        torch.nn.Softplus(beta=1.5, threshold=20),
+                        torch.nn.Softplus(beta=5, threshold=10)]
+    for softplus in softplus_modules:
+        verify_model(softplus.eval(), input_data=input_data)
+
+def test_forward_softsign():
+    """Check torch.nn.Softsign (x / (1 + |x|)) on inputs of both signs."""
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    # Rescale torch.rand's [0, 1) to [-1, 1) so the |x| term sees negatives.
+    input_data = torch.rand(input_shape).float() * 2 - 1
+    verify_model(torch.nn.Softsign().eval(), input_data=input_data)
+
def test_forward_log_sigmoid():
torch.set_grad_enabled(False)
input_shape = [10, 10]
@@ -1254,6 +1316,8 @@ if __name__ == "__main__":
test_forward_view()
test_forward_select()
test_forward_clone()
+ test_forward_softplus()
+ test_forward_softsign()
test_forward_logsoftmax()
test_forward_sigmoid()
test_forward_dense()
@@ -1264,6 +1328,8 @@ if __name__ == "__main__":
test_forward_mean()
test_forward_expand()
test_forward_pow()
+ test_forward_abs()
+ test_forward_arange()
test_forward_chunk()
test_forward_split()
test_upsample()