Posted to commits@tvm.apache.org by ma...@apache.org on 2020/11/02 19:41:05 UTC

[incubator-tvm] branch ci-docker-staging updated: [CI] Torch 1.7 update staging (#6825)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch ci-docker-staging
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/ci-docker-staging by this push:
     new 93cd1a1  [CI] Torch 1.7 update staging (#6825)
93cd1a1 is described below

commit 93cd1a1d166f70f86061f2c6948343fb16038576
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Nov 3 04:40:49 2020 +0900

    [CI] Torch 1.7 update staging (#6825)
    
    * fix norm and linspace test
    
    * fix upsampling conversion
    
    * update install script
    
    * remove print
    
    * fix pylint
    
    * disable quantized googlenet test
    
    * fix for object detection test
    
    * fix pylint
    
    * update ci-gpu
    
    * fix typo
    
    * update supported versions
---
 Jenkinsfile                                        |  2 +-
 docker/install/ubuntu_install_onnx.sh              |  2 +-
 python/tvm/relay/frontend/pytorch.py               | 60 +++++++++++++---------
 tests/python/frontend/pytorch/qnn_test.py          |  3 +-
 tests/python/frontend/pytorch/test_forward.py      |  4 +-
 .../frontend/deploy_object_detection_pytorch.py    |  6 +--
 tutorials/frontend/from_pytorch.py                 |  6 +--
 7 files changed, 47 insertions(+), 36 deletions(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 17ddbab..079001f 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -45,7 +45,7 @@
 
 // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
 ci_lint = "tlcpack/ci-lint:v0.62"
-ci_gpu = "tlcpack/ci-gpu:v0.71"
+ci_gpu = "tlcpack/ci-gpu:v0.72"
 ci_cpu = "tlcpack/ci-cpu:v0.71"
 ci_wasm = "tlcpack/ci-wasm:v0.70"
 ci_i386 = "tlcpack/ci-i386:v0.71"
diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh
index 2ad6019..a92a024 100755
--- a/docker/install/ubuntu_install_onnx.sh
+++ b/docker/install/ubuntu_install_onnx.sh
@@ -28,4 +28,4 @@ pip3 install onnxruntime==1.0.0
 # not expose that in the wheel!!!
 pip3 install future
 
-pip3 install torch==1.4.0 torchvision==0.5.0
+pip3 install torch==1.7.0 torchvision==0.8.1
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index d8c0769..2fd2078 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -21,6 +21,7 @@
 import itertools
 import logging
 import sys
+import math
 
 import numpy as np
 
@@ -168,7 +169,6 @@ def _min():
 
 def _unary(name):
     def _impl(inputs, input_types):
-        input_type = input_types[0]
         # this is just to ensure tensor input
         (data,) = _pytorch_promote_types(inputs[:1], input_types[:1])
         return get_relay_op(name)(data)
@@ -1552,7 +1552,7 @@ def _frobenius_norm():
         axis = None
         keepdims = False
         if len(inputs) > 2:
-            axis = inputs[1]
+            axis = inputs[1] if len(inputs[1]) > 0 else None
             keepdims = bool(inputs[2])
 
         return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims))
@@ -1847,18 +1847,33 @@ def _to():
     return _impl
 
 
-def _upsample(method, prelude):
-    def _impl(inputs, input_types):
-        out_size = []
+def _get_upsample_out_size(inputs, method):
+    # This assumes a static shape
+    out_size = []
+    if inputs[1] is not None:
         for size in inputs[1]:
             if not isinstance(size, int):
                 out_size.append(int(_infer_value(size, {}).asnumpy()))
             else:
                 out_size.append(size)
+    else:
+        scale_index = 3 if method in ["bilinear", "trilinear"] else 2
+        scales = inputs[scale_index]
+        assert scales is not None, "neither out size nor scale provided"
+        assert isinstance(scales, list)
+        ishape = _infer_shape(inputs[0])
+        for i, scale in enumerate(scales):
+            out_size.append(int(math.floor(float(ishape[2 + i]) * scale)))
+
+    return out_size
 
+
+def _upsample(method, prelude):
+    def _impl(inputs, input_types):
         data = inputs[0]
+        out_size = _get_upsample_out_size(inputs, method)
 
-        if len(inputs) > 2:
+        if len(inputs) > 2 and method == "bilinear":
             align_corners = inputs[2]
         else:
             align_corners = False
@@ -1874,17 +1889,13 @@ def _upsample(method, prelude):
             return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
 
         if _is_quantized_tensor(data, prelude):
-            # Torch version > 1.4 changed upsampling API
-            if is_version_greater_than("1.4.0"):
-                num_inputs = 7
-            else:
-                num_inputs = 5
-
-            assert len(inputs) == num_inputs, "Input quant param not found in op inputs"
-
+            # input qparams are manually appended by us
+            assert isinstance(inputs[-2], float)
+            assert isinstance(inputs[-1], int)
             input_scale = _expr.const(inputs[-2])
             input_zero_point = _expr.const(inputs[-1])
             return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func)
+
         return func(data)
 
     return _impl
@@ -1892,17 +1903,10 @@ def _upsample(method, prelude):
 
 def _upsample3d(method):
     def _impl(inputs, input_types):
-        if isinstance(inputs[1], _expr.Var):
-            out_size = _infer_shape(inputs[1])
-        elif _is_int_seq(inputs[1]):
-            out_size = inputs[1]
-        elif isinstance(inputs[1], list):
-            infer_res = [_infer_value(size, {}) for size in inputs[1]]
-            out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
-
         data = inputs[0]
+        out_size = _get_upsample_out_size(inputs, method)
 
-        if len(inputs) > 2:
+        if len(inputs) > 2 and method == "trilinear":
             align_corners = inputs[2]
         else:
             align_corners = False
@@ -1983,8 +1987,7 @@ def _bitwise_xor():
 
 def _logical_not():
     def _impl(inputs, input_types):
-        data = inputs[0]
-
+        data = _wrap_const(inputs[0])
         return _op.logical_not(_op.cast(data, "bool"))
 
     return _impl
@@ -2732,6 +2735,7 @@ def _get_convert_map(prelude, default_dtype):
         "aten::empty": _empty(),
         "aten::bincount": _bincount(),
         "aten::scatter_add": _scatter_add(),
+        "aten::__not__": _logical_not(),
     }
     return convert_map
 
@@ -2798,6 +2802,7 @@ def _report_missing_conversion(op_names, convert_map):
         "prim::ListUnpack",
         "prim::TupleConstruct",
         "prim::TupleUnpack",
+        "prim::RaiseException",
         "prim::If",
         "prim::Loop",
     ]
@@ -2903,6 +2908,8 @@ def _get_operator_nodes(nodes):
     ops = []
     # Traverse nodes and add to graph
     for node in nodes:
+        if node.outputsSize() == 0:
+            continue
         if node.outputsSize() > 1:
             node_name = "_".join(_get_output_names(node))
         else:
@@ -3286,6 +3293,9 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, defau
             else:
                 unpacked = _unpack_tuple(inputs[0])
             outputs.update(zip(_get_output_names(op_node), unpacked))
+        elif operator == "prim::RaiseException":
+            logging.warning("raising exceptions is ignored")
+            outputs[node_name] = None
         elif operator == "prim::If":
             if_out = convert_if(op_node, outputs, convert_map, prelude, default_dtype=default_dtype)
             outputs[node_name] = if_out
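
For reference: newer PyTorch versions (around 1.5+) may pass None for the
upsampling output size and supply per-dimension scale factors instead, which
is why the new _get_upsample_out_size helper falls back to deriving the size
from the scales against the static input shape. A minimal standalone sketch of
that fallback (the upsample_out_size function below is hypothetical, written
only to mirror the logic in the hunk above):

    import math

    def upsample_out_size(in_shape, out_size, scales):
        # Hypothetical helper mirroring _get_upsample_out_size: prefer an
        # explicit output size, else derive one from the scale factors
        # applied to the static spatial dims (index 2+ for NCHW/NCDHW).
        if out_size is not None:
            return list(out_size)
        assert scales is not None, "neither out size nor scale provided"
        return [int(math.floor(in_shape[2 + i] * s)) for i, s in enumerate(scales)]

    print(upsample_out_size((1, 3, 32, 32), None, [2.0, 2.0]))  # -> [64, 64]

The zero-output node skip in _get_operator_nodes and the prim::RaiseException
handling go together: a RaiseException node produces no outputs, and
TorchScript lowers a Python assert into a prim::If whose failure branch
raises. A quick, illustrative way to see such a graph:

    import torch

    @torch.jit.script
    def relu_nonempty(x: torch.Tensor) -> torch.Tensor:
        assert x.numel() > 0, "empty input"  # lowers to prim::RaiseException
        return torch.relu(x)

    print(relu_nonempty.graph)  # shows prim::If with a prim::RaiseException branch
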
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 1851e31..9781eb5 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -367,7 +367,8 @@ def test_quantized_imagenet():
             # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
             # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
             # ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
-            ("googlenet", qgooglenet(pretrained=True), per_channel),
+            # tracing quantized googlenet broken as of v1.6
+            # ("googlenet", qgooglenet(pretrained=True), per_channel),
         ]
 
     results = []
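
Context for the disabled entry above: the quantized models in this test are
quantized and traced before being handed to the TVM frontend, and tracing the
quantized googlenet started failing around torch 1.6. A rough sketch of one
way to obtain a quantized, traceable torchvision model (the test's own
quantization setup may differ; resnet18 is an arbitrary stand-in):

    import torch
    from torchvision.models.quantization import resnet18

    # Load a pre-quantized torchvision model and trace it for import.
    model = resnet18(pretrained=True, quantize=True).eval()
    inp = torch.rand(1, 3, 224, 224)
    script_module = torch.jit.trace(model, inp)
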
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e997ebe..4dec5f7 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2535,7 +2535,7 @@ def test_forward_linspace():
 
     class Linspace1(Module):
         def forward(self, *args):
-            return torch.linspace(5, 10)
+            return torch.linspace(5, 10, steps=100)
 
     class Linspace2(Module):
         def forward(self, *args):
@@ -2559,7 +2559,7 @@ def test_forward_linspace():
 
     class Linspace7(Module):
         def forward(self, *args):
-            return torch.linspace(1, 4, dtype=torch.float32)
+            return torch.linspace(1, 4, steps=100, dtype=torch.float32)
 
     class Linspace8(Module):
         def forward(self, *args):
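
The steps additions track a PyTorch 1.7 change: calling torch.linspace
without steps (whose implicit default was 100) is deprecated there, so the
tests now spell it out. For illustration:

    import torch

    # Explicit steps keeps the pre-1.7 behavior without the deprecation warning.
    a = torch.linspace(5, 10, steps=100)
    b = torch.linspace(1, 4, steps=100, dtype=torch.float32)
    print(a.shape, b.dtype)  # torch.Size([100]) torch.float32
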
diff --git a/tutorials/frontend/deploy_object_detection_pytorch.py b/tutorials/frontend/deploy_object_detection_pytorch.py
index 6408685..2852dd3 100644
--- a/tutorials/frontend/deploy_object_detection_pytorch.py
+++ b/tutorials/frontend/deploy_object_detection_pytorch.py
@@ -27,8 +27,8 @@ A quick solution is to install via pip
 
 .. code-block:: bash
 
-    pip install torch==1.4.0
-    pip install torchvision==0.5.0
+    pip install torch==1.7.0
+    pip install torchvision==0.8.1
 
 or please refer to official site
 https://pytorch.org/get-started/locally/
@@ -36,7 +36,7 @@ https://pytorch.org/get-started/locally/
 PyTorch versions should be backwards compatible but should be used
 with the proper TorchVision version.
 
-Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
+Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
 be unstable.
 """
 
diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py
index 33a0588..b5bcdf6 100644
--- a/tutorials/frontend/from_pytorch.py
+++ b/tutorials/frontend/from_pytorch.py
@@ -28,8 +28,8 @@ A quick solution is to install via pip
 
 .. code-block:: bash
 
-    pip install torch==1.4.0
-    pip install torchvision==0.5.0
+    pip install torch==1.7.0
+    pip install torchvision==0.8.1
 
 or please refer to official site
 https://pytorch.org/get-started/locally/
@@ -37,7 +37,7 @@ https://pytorch.org/get-started/locally/
 PyTorch versions should be backwards compatible but should be used
 with the proper TorchVision version.
 
-Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
+Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may
 be unstable.
 """