You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/04/11 05:03:06 UTC

[incubator-tvm] branch master updated: [PYTORCH]Abs, Arange, Softplus ops (#5295)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 5b37d4c  [PYTORCH]Abs, Arange, Softplus ops (#5295)
5b37d4c is described below

commit 5b37d4c15378e872c279ca5edbcb077d1a5fd20b
Author: Samuel <si...@huawei.com>
AuthorDate: Sat Apr 11 10:32:58 2020 +0530

    [PYTORCH]Abs, Arange, Softplus ops (#5295)
    
    * [PYTORCH]Abs, Arange, Softplus ops
    
    * Review comments updated
---
 python/tvm/relay/frontend/pytorch.py          | 52 +++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 66 +++++++++++++++++++++++++++
 2 files changed, 118 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b8b32e7..a542ccc 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -57,6 +57,33 @@ def _elemwise(name):
         return get_relay_op(name)(data0, data1)
     return _impl
 
+def _abs():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.abs(data)
+    return _impl
+
+def _arange():
+    def _impl(inputs, input_types):
+        if len(inputs) == 5:
+            dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1])
+            start = _create_typed_const(0, dtype)
+            stop = _create_typed_const(inputs[0], dtype)
+            step = _create_typed_const(1, dtype)
+        elif len(inputs) == 7:
+            dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
+            start = _create_typed_const(inputs[0], dtype)
+            stop = _create_typed_const(inputs[1], dtype)
+            step = _create_typed_const(inputs[2], dtype)
+        else:
+            msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
+            raise AssertionError(msg)
+        return _op.transform.arange(start=start,
+                                    stop=stop,
+                                    step=step,
+                                    dtype=_convert_data_type(dtype))
+    return _impl
+
 def _squeeze():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -732,6 +759,13 @@ def _sigmoid():
         return _op.tensor.sigmoid(data)
     return _impl
 
+def _softplus():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        beta = _expr.const(float(inputs[1]))
+        return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
+    return _impl
+
 def _avg_pool2d():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1044,6 +1078,21 @@ def _Float():
     return _impl
 
 # Helper functions for operator implementation
+def _convert_dtype_value(val):
+    convert_torch_dtype_map = {7:"torch.float64",
+                               6:"torch.float32",
+                               5:"torch.float16",
+                               4:"torch.int64",
+                               3:"torch.int32",
+                               2:"torch.int16",
+                               1:"torch.int8",
+                               0:"torch.unit8",
+                               None:"torch.int64"} # Default is torch.int64
+    if val in convert_torch_dtype_map:
+        return convert_torch_dtype_map[val]
+    else:
+        msg = "Torch data type value %d is not handled yet." % (val)
+        raise NotImplementedError(msg)
 
 def _convert_data_type(input_type):
     if input_type in ["double", "torch.float64"]:
@@ -1118,6 +1167,8 @@ _convert_map = {
     "aten::pow"                             : _elemwise("power"),
     "aten::div"                             : _elemwise("divide"),
     "aten::div_"                            : _elemwise("divide"),
+    "aten::abs"                             : _abs(),
+    "aten::arange"                          : _arange(),
     "aten::ones"                            : _ones(),
     "aten::zeros"                           : _zeros(),
     "aten::reciprocal"                      : _reciprocal(),
@@ -1167,6 +1218,7 @@ _convert_map = {
     "aten::clone"                           : _clone(),
     "aten::log_softmax"                     : _log_softmax(),
     "aten::sigmoid"                         : _sigmoid(),
+    "aten::softplus"                        : _softplus(),
     "aten::avg_pool2d"                      : _avg_pool2d(),
     "aten::avg_pool3d"                      : _avg_pool3d(),
     "aten::dropout"                         : _dropout(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 4226463..d60ab9e 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -375,6 +375,54 @@ def test_forward_squeeze():
     verify_model(Squeeze1().float().eval(), input_data=input_data)
     verify_model(Squeeze2().float().eval(), input_data=input_data)
 
+def test_forward_arange():
+    torch.set_grad_enabled(False)
+
+    class Arange1(Module):
+        def forward(self, *args):
+            return torch.arange(5)
+    class Arange2(Module):
+        def forward(self, *args):
+            return torch.arange(2.5)
+    class Arange3(Module):
+        def forward(self, *args):
+            return torch.arange(1, 4)
+    class Arange4(Module):
+        def forward(self, *args):
+            return torch.arange(1, 2.5, 0.5)
+    class Arange5(Module):
+        def forward(self, *args):
+            return torch.arange(1, 2, 1, dtype=torch.int32)
+    class Arange6(Module):
+        def forward(self, *args):
+            return torch.arange(start=1, end=6, step=2)
+    class Arange7(Module):
+        def forward(self, *args):
+            return torch.arange(1, 4, dtype=torch.float32)
+    class Arange8(Module):
+        def forward(self, *args):
+            return torch.arange(1, 2, 1, dtype=torch.int16)
+
+    verify_model(Arange1().float().eval())
+    verify_model(Arange2().float().eval())
+    verify_model(Arange3().float().eval())
+    verify_model(Arange4().float().eval())
+    verify_model(Arange5().float().eval())
+    verify_model(Arange6().float().eval())
+    verify_model(Arange7().float().eval())
+    verify_model(Arange8().float().eval())
+
+def test_forward_abs():
+    torch.set_grad_enabled(False)
+    input_shape = [2, 1, 10, 1, 10]
+
+    class Abs1(Module):
+        def forward(self, *args):
+            return args[0].abs()
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Abs1().float().eval(), input_data=input_data)
+
 def test_forward_concatenate():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -445,6 +493,20 @@ def test_forward_selu():
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.SELU().eval(), input_data=input_data)
 
+def test_forward_softplus():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.Softplus().eval(), input_data=input_data)
+    verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data)
+    verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data)
+
+def test_forward_softsign():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.Softsign().eval(), input_data=input_data)
+
 def test_forward_log_sigmoid():
     torch.set_grad_enabled(False)
     input_shape = [10, 10]
@@ -1254,6 +1316,8 @@ if __name__ == "__main__":
     test_forward_view()
     test_forward_select()
     test_forward_clone()
+    test_forward_softplus()
+    test_forward_softsign()
     test_forward_logsoftmax()
     test_forward_sigmoid()
     test_forward_dense()
@@ -1264,6 +1328,8 @@ if __name__ == "__main__":
     test_forward_mean()
     test_forward_expand()
     test_forward_pow()
+    test_forward_abs()
+    test_forward_arange()
     test_forward_chunk()
     test_forward_split()
     test_upsample()