Posted to discuss-archive@tvm.apache.org by wyc0926 via Apache TVM Discuss <no...@discuss.tvm.ai> on 2022/02/09 05:43:47 UTC

[Apache TVM Discuss] [Questions] [Pytorch] The inference results of tvm and pytorch are inconsistent


Hi,

I created a PyTorch quantized model, compiled it with TVM, and ran inference. The TVM result is inconsistent with PyTorch's, and the strange thing is that the mismatch only occurs some of the time.

My code:
```
import numpy as np
import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub, get_default_qat_qconfig, convert, prepare_qat

import tvm
from tvm import relay
from tvm.contrib import graph_executor

class AdaptiveAvgPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.quant(x)
        y = self.pool(x)
        y = self.dequant(y)
        return y

    def fuse_model(self):
        pass

# FP32 example input used for calibration and tracing.
fp32_input = torch.randn(1, 3, 128, 128)
model = AdaptiveAvgPool2d()

# Use the qnnpack quantization backend.
BACKEND = "qnnpack"
model.qconfig = get_default_qat_qconfig(BACKEND)
torch.backends.quantized.engine = BACKEND

# Insert fake-quantization observers for QAT.
prepare_qat(model, inplace=True)

# Run once in eval mode to collect quantization statistics.
model.eval()
y = model(fp32_input)

# Convert to a real int8 model and trace it for the TVM frontend.
model_int8 = convert(model, inplace=True)
script_module = torch.jit.trace(model_int8, fp32_input).eval()

input_name = "input"
input_infos = [(input_name, ((1, 3, 128, 128), 'float32'))]
img_input = np.random.rand(1, 3, 128, 128).astype(np.float32)
pt_input = torch.from_numpy(img_input)

# PyTorch reference result.
with torch.no_grad():
    pt_result = script_module(pt_input)

# Import the traced model into Relay and compile for CPU.
mod, params = relay.frontend.from_pytorch(script_module, input_infos)
target = "llvm"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

# Run the compiled module on the same input and compare.
module = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
module.set_input(input_name, img_input)
module.run()
tvm_flat = module.get_output(0).asnumpy().flatten()
pt_flat = pt_result[0].numpy().flatten()

print(pt_flat)
print(tvm_flat)
print("compare result: ", np.allclose(pt_flat, tvm_flat, atol=1e-05))
```
If you run the above code repeatedly, the comparison result is sometimes True and sometimes False. Why is this?

A run where the results match:
```
[0.48794654 0.48794654 0.48794654]
[0.48794654 0.48794654 0.48794654]
compare result:  True
```
A run where they differ:
```
[0.5066974 0.5066974 0.5066974]
[0.47291756 0.47291756 0.47291756]
compare result:  False
```





---

Posted by wyc0926 via Apache TVM Discuss <no...@discuss.tvm.ai>.

I analyzed the outputs. Whenever the results disagree, the TVM result differs from the PyTorch result by exactly one quantization scale. So I suspect that PyTorch's quantized adaptive_avg_pool2d rounds the averaged value to the nearest integer, while TVM truncates it (discards the fractional part). Whenever the true quotient would round up, TVM's result comes out one scale smaller than PyTorch's.
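
To illustrate the suspicion, here is a minimal sketch (not from the thread; the scale and integer values are made up, assuming per-tensor quantization with zero point 0) of how truncating versus rounding the integer average differs by exactly one scale:

```
import numpy as np

s = 0.01                          # hypothetical quantization scale
q = np.array([14, 15, 15, 14], dtype=np.int32)  # int values in one pooling window

total = int(q.sum())              # 58
n = q.size                        # 4; the exact average is 14.5

truncated = total // n            # 14 <- truncating division (TVM's current div)
rounded = (total + n // 2) // n   # 15 <- round-half-up (what PyTorch appears to do)

print(truncated * s, rounded * s)  # the two results differ by exactly one scale s
```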

@masahi What's your opinion on this issue?





---

Posted by masahi via Apache TVM Discuss <no...@discuss.tvm.ai>.

Interesting, does PyTorch do something like that? It's not obvious to me that we can make this change safely. Would it change the output of every adaptive avg pool? What about normal avg pooling?
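
One way to probe PyTorch's behavior directly (a minimal sketch, not from the thread; scale and values are arbitrary) is to pool a quantized window whose integer values average to exactly .5 and inspect the integer result:

```
import torch
import torch.nn.functional as F

# A 2x2 window whose quantized integer values are [1, 2, 2, 1]: sum 6, average 1.5.
x = torch.tensor([[[[0.1, 0.2],
                    [0.2, 0.1]]]])
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

out = F.adaptive_avg_pool2d(q, (1, 1))
print(out.int_repr())  # 2 would mean PyTorch rounds the average; 1 would mean it truncates
```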





---

Posted by wyc0926 via Apache TVM Discuss <no...@discuss.tvm.ai>.

I think rounding is necessary to preserve model accuracy. A possible fix:
```
diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index c81c7cda7..467d2f5d8 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -386,7 +386,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
             divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
           }
 
-          return div(pool_sum(indices), divide_factor);
+          return div(pool_sum(indices) + div(divide_factor, 2), divide_factor);
         },
         "tensor", kElementWise);
   } else {
```
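
In plain terms, the patch turns truncating integer division into round-half-up division by adding half the divisor to the sum before dividing. A Python sketch of the same idea (assuming positive integer sums, as in unsigned quantized pooling):

```
def div_truncate(pool_sum, divide_factor):
    # Current behavior: integer division discards the fractional part.
    return pool_sum // divide_factor

def div_round(pool_sum, divide_factor):
    # Patched behavior: biasing by half the divisor rounds half up.
    return (pool_sum + divide_factor // 2) // divide_factor

print(div_truncate(58, 4), div_round(58, 4))  # 14 15
```

Note that as written the patch applies to every dtype, so it would presumably also shift float-typed pooling, which may be behind the concern above about changing the output of every adaptive avg pool.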




