Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/06/18 03:36:45 UTC

[GitHub] [incubator-tvm] siju-samuel commented on a change in pull request #5839: [Torch][Quantized] Fix converting serialized quantized models

siju-samuel commented on a change in pull request #5839:
URL: https://github.com/apache/incubator-tvm/pull/5839#discussion_r441948779



##########
File path: tests/python/frontend/pytorch/qnn_test.py
##########
@@ -465,3 +477,31 @@ def get_imagenet_input():
         mean abs_diff: 0.054197952
         558 in 1000 raw outputs identical.
         """
+
+
+def test_serialized_modules():
+    ishape = (1, 16, 64, 64)
+    raw_module = AdaptiveAvgPool2d().eval()
+    inp = torch.rand(ishape)
+
+    quantize_model(raw_module, inp)
+    script_module = torch.jit.trace(raw_module, inp).eval()
+
+    fname = "tmp.pt"
+    torch.jit.save(script_module, fname)
+    loaded = torch.jit.load(fname)
+
+    with torch.no_grad():
+        pt_result = loaded(inp.clone()).numpy()
+
+    input_name = "input"
+    runtime = get_tvm_runtime(loaded, input_name, ishape)
+    runtime.set_input(input_name, inp.numpy().copy())
+    runtime.run()
+    tvm_result = runtime.get_output(0).asnumpy()
+
+    num_identical = np.sum(tvm_result == pt_result)
+    match_ratio = num_identical / float(np.prod(tvm_result.shape))
+    assert match_ratio > 0.2
+
+    os.remove(fname)

Review comment:
       The os.remove(fname) call can be moved to right after torch.jit.load; otherwise, if the assert fails, the cleanup won't happen.
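A minimal sketch of the suggested ordering follows. It assumes pickle as a stand-in for torch.jit.save / torch.jit.load so that it runs without PyTorch, and the helper name save_and_reload is hypothetical. The temporary file is deleted in a finally block immediately after loading, so a failing comparison later in the test cannot leave it behind.

```python
import os
import pickle
import tempfile


def save_and_reload(obj):
    """Hypothetical helper: round-trip obj through a temporary file.

    pickle stands in for torch.jit.save / torch.jit.load (an assumption made
    so this sketch runs without PyTorch installed).
    """
    fd, fname = tempfile.mkstemp(suffix=".pt")
    os.close(fd)
    try:
        with open(fname, "wb") as f:
            pickle.dump(obj, f)
        with open(fname, "rb") as f:
            loaded = pickle.load(f)
    finally:
        # Cleanup runs right after loading, before any assertion can fail.
        os.remove(fname)
    return loaded, fname


def test_roundtrip():
    loaded, fname = save_and_reload({"scale": 0.5, "zero_point": 128})
    # The file is already gone, so a failing assert below cannot leak it.
    assert not os.path.exists(fname)
    assert loaded == {"scale": 0.5, "zero_point": 128}
```

Using tempfile.mkstemp instead of a fixed "tmp.pt" also avoids clobbering an existing file in the working directory when tests run in parallel.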

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -595,15 +597,19 @@ def _impl(inputs, input_types):
         return _op.log(_op.tensor.sigmoid(data))
     return _impl
 
-def _adaptive_avg_pool_2d():
+def _adaptive_avg_pool_2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
         output_size = _infer_shape(inputs[1])
 
         def func(x):
             return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
 
-        if input_types[0] == "quint8":
+        ty = _infer_type_with_prelude(data, prelude)
+        # If a quantized Torch module is saved and loaded back, dtype will be dropped
+        # input_types[0] can be float even though the input is a quantized tensor
+        # To reliably determine input types, we use Relay's type inference result
+        if ty.dtype == "uint8":

Review comment:
       One suggestion:
       Can we have a wrapper that checks whether the data is quantized? That would make it easy to extend this check to other ops, and if PyTorch fixes the issue in the future, our changes will be minimal.
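A rough sketch of the suggested wrapper follows; it is not the PR's actual code. The names is_quantized_tensor and choose_pool_path are hypothetical, and infer_type stands in for a call such as lambda d: _infer_type_with_prelude(d, prelude) in the frontend. A fake type table replaces Relay here so the sketch runs on its own. Every converter calls the one helper, so if PyTorch later preserves the dtype on save/load, only the helper body needs to change (for example, back to checking input_types).

```python
from types import SimpleNamespace


def is_quantized_tensor(data, infer_type):
    """Hypothetical wrapper: decide quantization from the inferred dtype.

    infer_type stands in for Relay type inference in the PyTorch frontend.
    """
    # After torch.jit.save/load, input_types may report "float" for a
    # quantized tensor, so the dtype inferred by Relay is trusted instead.
    return infer_type(data).dtype == "uint8"


def choose_pool_path(data, infer_type):
    """Hypothetical converter body showing how an op would use the wrapper."""
    if is_quantized_tensor(data, infer_type):
        return "quantized"  # the op's quantized lowering would go here
    return "float"


# Fake type inference for illustration only: maps a name to a type-like object.
fake_types = {
    "q_input": SimpleNamespace(dtype="uint8"),
    "f_input": SimpleNamespace(dtype="float32"),
}
infer = fake_types.__getitem__

print(choose_pool_path("q_input", infer))  # quantized
print(choose_pool_path("f_input", infer))  # float
```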




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org