Posted to discuss-archive@tvm.apache.org by Ligeng Zhu via Apache TVM Discuss <no...@discuss.tvm.ai> on 2021/11/03 15:03:19 UTC

[Apache TVM Discuss] [Questions] Relay failed to build models exported from pytorch


Hi there,

While following the [relay quick start](https://tvm.apache.org/docs/tutorial/relay_quick_start.html) tutorial, I tried to load a module from PyTorch, but it raises a segmentation fault. I am using TVM at the latest commit, `bff98843bef9a312587aaff51b679d9b69a7d5a7`, and the code to reproduce is attached below:

```
from tvm import relay
import tvm
import numpy as np

import torch
import torch as th
import torch.nn as nn
from torchvision import models

import torch.onnx 
from tvm import relay, auto_scheduler

model = nn.Sequential(
    nn.Conv2d(3, 3, kernel_size=3, padding=1),
    nn.BatchNorm2d(3),
    # nn.Dropout()
)

input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
input_name = "input0"
shape_list = [(input_name, input_data.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
print(mod['main'])

opt_level = 3
# target = tvm.target.cuda()
target = "llvm"
with tvm.transform.PassContext(opt_level=opt_level):
    lib = relay.build(mod, target=target, params=params)
```





---
[Visit Topic](https://discuss.tvm.apache.org/t/relay-failed-to-build-models-exported-from-pytorch/11394/1) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/bf147477f2eac3b264299539391f25aadafee0426a6de59a42beb441de850a9e).

Posted by Ligeng Zhu via Apache TVM Discuss <no...@discuss.tvm.ai>.

I noticed something interesting that might help with this bug.

I tried to compile a model exported from mxnet instead:

```
from tvm import relay
import tvm
import numpy as np

import torch
import torch as th
import torch.nn as nn
from torchvision import models
import torch.onnx 

from tvm import relay, auto_scheduler
from tvm.relay import testing


mod, params = relay.testing.resnet.get_workload(
    num_layers=18, batch_size=1, image_shape=(3, 224, 224)
)

opt_level = 3
# target = tvm.target.cuda()
target = "llvm"
with tvm.transform.PassContext(opt_level=opt_level):
    lib = relay.build(mod, target=target, params=params)

print("build successful")
```

It also raises the same segmentation fault. However, if I comment out all the PyTorch imports

```
import torch
import torch as th
import torch.nn as nn
from torchvision import models
import torch.onnx 
```

then `relay.build` works without a problem. Any thoughts on this strange behavior?





---

Posted by Ligeng Zhu via Apache TVM Discuss <no...@discuss.tvm.ai>.

[quote="lhutton1, post:3, topic:11394"]
elegant
[/quote]

Oh! Thanks for the information. You saved my day.





---

Posted by Luke Hutton via Apache TVM Discuss <no...@discuss.tvm.ai>.

Hi @Lyken17,

I also ran into this issue recently. It turned out to be caused by conflicting symbols between PyTorch and TVM; see https://github.com/apache/tvm/issues/9362#issuecomment-955263494 for the resolution. Alternatively, a quicker (but less elegant) solution is to import `torch` before `tvm`. Hope this helps!
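
For anyone who wants to check whether a script already follows the import-order workaround, here is a minimal sketch. The helper name `tvm_imported_before_torch` is made up for illustration and is not part of TVM or PyTorch. It assumes Python 3.7+, where `sys.modules` is a dict that keeps insertion order, so the position of the `torch` and `tvm` keys roughly reflects which one was imported first. It only inspects module names; it does not detect the symbol conflict itself.

```python
import sys


def tvm_imported_before_torch(module_names=None):
    """Return True if both modules are loaded and `tvm` was loaded first.

    `module_names` is an iterable of module names in import order. It
    defaults to the keys of `sys.modules` (insertion-ordered on Python 3.7+).
    This is a hypothetical helper for illustration, not a TVM or PyTorch API.
    """
    names = list(sys.modules if module_names is None else module_names)
    if "torch" not in names or "tvm" not in names:
        # At most one of the two is loaded, so there is no ordering to check.
        return False
    return names.index("tvm") < names.index("torch")


# Explicit orderings, so the example runs without torch or tvm installed.
bad_order = ["tvm", "tvm.relay", "torch"]   # like the scripts in this thread
good_order = ["torch", "tvm", "tvm.relay"]  # the suggested workaround

print(tvm_imported_before_torch(bad_order))
print(tvm_imported_before_torch(good_order))
```

Called with no arguments after the imports at the top of a script, it returns `True` when `tvm` came first, which is the order used in both scripts above.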





---

Posted by Ligeng Zhu via Apache TVM Discuss <no...@discuss.tvm.ai>.

My environment:

ubuntu 20.04 | gcc: 9.3 | llvm: 10.0 | nvcc: 11.1




