Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/03 13:38:13 UTC

[GitHub] [incubator-tvm] masahi commented on a change in pull request #6811: [Torch, CI] Update to PyTorch 1.7

masahi commented on a change in pull request #6811:
URL: https://github.com/apache/incubator-tvm/pull/6811#discussion_r515871753



##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1847,18 +1847,33 @@ def _impl(inputs, input_types):
     return _impl
 
 
-def _upsample(method, prelude):
-    def _impl(inputs, input_types):
-        out_size = []
+def _get_upsample_out_size(inputs, method):
+    # This assumes a static shape
+    out_size = []
+    if inputs[1] is not None:
         for size in inputs[1]:
             if not isinstance(size, int):
                 out_size.append(int(_infer_value(size, {}).asnumpy()))
             else:
                 out_size.append(size)
+    else:
+        scale_index = 3 if method in ["bilinear", "trilinear"] else 2
+        scales = inputs[scale_index]
+        assert scales is not None, "neither out size nor scale provided"
+        assert isinstance(scales, list)
+        ishape = _infer_shape(inputs[0])
+        for i, scale in enumerate(scales):
+            out_size.append(int(math.floor(math.floor(ishape[2 + i] * scale))))

Review comment:
       haha good find, the inner one is a typo for `float`
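    For reference, a minimal standalone sketch of the intended size computation with that typo fixed (the inner call should be `float`, not a second `math.floor`); `ishape` and `scales` here are just stand-ins for the inferred input shape and the scale factors:

```python
import math

def compute_out_size(ishape, scales):
    # Spatial dims start at index 2 (NCHW / NCDHW); each output dim is
    # floor(input_dim * scale), matching PyTorch's upsample semantics.
    out_size = []
    for i, scale in enumerate(scales):
        out_size.append(int(math.floor(float(ishape[2 + i]) * scale)))
    return out_size

print(compute_out_size((1, 3, 10, 10), [2.0, 2.0]))  # [20, 20]
```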

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1874,35 +1889,24 @@ def func(x):
             return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
 
         if _is_quantized_tensor(data, prelude):
-            # Torch version > 1.4 changed upsampling API
-            if is_version_greater_than("1.4.0"):
-                num_inputs = 7
-            else:
-                num_inputs = 5
-
-            assert len(inputs) == num_inputs, "Input quant param not found in op inputs"
-
+            # input qparams are manually appended by us
+            assert isinstance(inputs[-2], float)
+            assert isinstance(inputs[-1], int)
             input_scale = _expr.const(inputs[-2])
             input_zero_point = _expr.const(inputs[-1])
             return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func)
+
         return func(data)
 
     return _impl
 
 
 def _upsample3d(method):
     def _impl(inputs, input_types):
-        if isinstance(inputs[1], _expr.Var):
-            out_size = _infer_shape(inputs[1])
-        elif _is_int_seq(inputs[1]):
-            out_size = inputs[1]
-        elif isinstance(inputs[1], list):
-            infer_res = [_infer_value(size, {}) for size in inputs[1]]
-            out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
-
         data = inputs[0]
+        out_size = _get_upsample_out_size(inputs, method)
 
-        if len(inputs) > 2:
+        if len(inputs) > 2 and method == "trilinear":

Review comment:
       I think it's fine (I don't want to see `BILINEAR` or `TRILINEAR`)
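    For context, these lowercase strings are forwarded verbatim as the resize method, so plain literals line up with what the resize op expects. A rough, hypothetical sketch of that flow (the op names and mapping below are assumptions for illustration, not copied from the frontend):

```python
# Hypothetical stand-in converter, only to show where the lowercase method
# strings originate and that they are passed through unchanged.
def _upsample3d(method):
    def _impl(inputs, input_types):
        # in the real frontend this would become an image resize call with
        # method=method; here we just echo the string back
        return ("resize3d", method)
    return _impl

# Assumed op-to-converter mapping, in the spirit of the frontend's convert map.
convert_map = {
    "aten::upsample_trilinear3d": _upsample3d("trilinear"),
    "aten::upsample_nearest3d": _upsample3d("nearest_neighbor"),
}

print(convert_map["aten::upsample_trilinear3d"]([], []))  # ('resize3d', 'trilinear')
```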

##########
File path: docker/install/ubuntu_install_onnx.sh
##########
@@ -28,4 +28,4 @@ pip3 install onnxruntime==1.0.0
 # not expose that in the wheel!!!
 pip3 install future
 
-pip3 install torch==1.4.0 torchvision==0.5.0
+pip3 install torch==1.7.0 torchvision==0.8.1

Review comment:
       Thanks, I've updated the version numbers in https://github.com/apache/incubator-tvm/pull/6825
   Once the tests pass in the staging branch, I'll update this PR.
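    Once the new CI image is built, a quick way to confirm the pins took effect is to check the versions from Python (a hedged sanity check, not part of the install script itself):

```python
# Confirm the image picked up the expected torch / torchvision pins.
import torch
import torchvision

assert torch.__version__.startswith("1.7"), torch.__version__
assert torchvision.__version__.startswith("0.8"), torchvision.__version__
print("torch", torch.__version__, "| torchvision", torchvision.__version__)
```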




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org