Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/25 05:17:25 UTC
[GitHub] [tvm] echuraev opened a new issue #7530: [Bug] Error in constant folding when model has dropout operator
echuraev opened a new issue #7530:
URL: https://github.com/apache/tvm/issues/7530
I have a problem with constant folding while converting the `detr` model. When I call `bind_params_by_name` and then apply the `FoldConstant` pass to the module, I get the following error message:
```
Traceback (most recent call last):
File "detr_reproducer.py", line 53, in <module>
lib = convert_pytorch_model(model)
File "detr_reproducer.py", line 42, in convert_pytorch_model
mod = seq(mod) # The problem is here
File "/Users/echuraev/Workspace/OctoML/tvm/python/tvm/ir/transform.py", line 130, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) 9 libtvm.dylib 0x0000000127d9d68e tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) + 350
[bt] (7) 8 libtvm.dylib 0x0000000127d9dda8 tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>*) const + 440
[bt] (6) 7 libtvm.dylib 0x0000000127da129d tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::InitVTable()::'lambda4'(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>*)::__invoke(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>*) + 29
[bt] (5) 6 libtvm.dylib 0x0000000127da12e5 tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::InitVTable()::'lambda4'(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>*)::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>*) const + 53
[bt] (4) 5 libtvm.dylib 0x0000000129a54e61 tvm::relay::IndexedForwardGraph::Creator::VisitExpr_(tvm::relay::CallNode const*) + 705
[bt] (3) 4 libtvm.dylib 0x00000001298773c0 tvm::AttrRegistryMap<tvm::Op, int>::operator[](tvm::Op const&) const + 32
[bt] (2) 3 libtvm.dylib 0x00000001277f5e3b tvm::AttrRegistryMapContainerMap<tvm::Op>::operator[](tvm::Op const&) const + 635
[bt] (1) 2 libtvm.dylib 0x0000000127488d95 dmlc::LogMessageFatal::~LogMessageFatal() + 21
[bt] (0) 1 libtvm.dylib 0x000000012748c2c1 dmlc::LogMessageFatal::~LogMessageFatal() + 65
File "/Users/echuraev/Workspace/OctoML/tvm/include/tvm/node/attr_registry_map.h", line 63
TVMError:
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
Check failed: idx < data_.size() && data_[idx].second != 0 == false: Attribute TOpPattern has not been registered for nn.dropout
```
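In the backtrace, `IndexedForwardGraph::Creator::VisitExpr_` looks up the `TOpPattern` of each call's op through `AttrRegistryMap::operator[]`, and that indexing operator treats a missing entry as an internal invariant violation. TVM's attribute maps also offer a `get(op, default)` form that returns a fallback instead. The toy Python model below is only a sketch of that difference; `ToyAttrMap`, `strict_pattern` and `lenient_pattern` are hypothetical names, not TVM's API.

```python
from enum import IntEnum


class OpPattern(IntEnum):
    # Subset of TVM's OpPatternKind values (kElemWise, kBroadcast, kOpaque).
    ELEMWISE = 0
    BROADCAST = 1
    OPAQUE = 8


class ToyAttrMap:
    """Hypothetical stand-in for a per-op attribute map; not TVM's API."""

    def __init__(self, attr_name, entries):
        self.attr_name = attr_name
        self._entries = dict(entries)

    def __getitem__(self, op_name):
        # Strict lookup: a missing entry is fatal, like the failed check above.
        if op_name not in self._entries:
            raise KeyError(
                f"Attribute {self.attr_name} has not been registered for {op_name}"
            )
        return self._entries[op_name]

    def get(self, op_name, default):
        # Lenient lookup: a missing entry yields the caller's default.
        return self._entries.get(op_name, default)


# "nn.dropout" deliberately has no entry, mirroring the error message.
fpattern = ToyAttrMap(
    "TOpPattern", {"add": OpPattern.BROADCAST, "nn.relu": OpPattern.ELEMWISE}
)


def strict_pattern(op_name):
    return fpattern[op_name]


def lenient_pattern(op_name):
    # Ops without a registered pattern are conservatively treated as opaque.
    return fpattern.get(op_name, OpPattern.OPAQUE)


if __name__ == "__main__":
    try:
        strict_pattern("nn.dropout")
    except KeyError as err:
        print("strict lookup failed:", err)
    print("lenient lookup:", lenient_pattern("nn.dropout").name)
```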
I can remove the `dropout` operator by running the `SimplifyInference` pass before `FoldConstant`. However, `FoldConstant` is a publicly available transformation, and any function containing a `dropout` op makes it crash. This seems wrong and looks like a bug.
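The workaround above relies only on pass ordering: if `SimplifyInference` rewrites `dropout` away first, the later pass never has to look up a pattern for it. In the reproducer script this would mean putting `transform.SimplifyInference()` between `transform.InferType()` and `transform.FoldConstant()` in the `Sequential` list. Here is a minimal toy model of that ordering effect in plain Python; the op chain, the pattern table and the function names are hypothetical stand-ins, not TVM's implementation.

```python
# Hypothetical toy model: a "function" is a chain of op names; not TVM's API.
PATTERNS = {"nn.conv2d": "out_elemwise_fusable", "add": "broadcast", "nn.relu": "elemwise"}


def simplify_inference(ops):
    # At inference time dropout passes its input through unchanged,
    # so this toy rewrite simply removes it from the chain.
    return [op for op in ops if op != "nn.dropout"]


def fold_constant(ops):
    # Toy stand-in for a pass that needs a fusion pattern for every op it visits.
    for op in ops:
        if op not in PATTERNS:
            raise RuntimeError(f"Attribute TOpPattern has not been registered for {op}")
    return ops


def run_sequential(passes, ops):
    # Apply passes in order, like a Sequential pipeline.
    for p in passes:
        ops = p(ops)
    return ops


if __name__ == "__main__":
    chain = ["nn.conv2d", "nn.dropout", "add", "nn.relu"]
    print(run_sequential([simplify_inference, fold_constant], chain))
```

Swapping the two passes in the list reproduces the failure, because the pattern check then sees `nn.dropout`.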
I prepared a simple script for reproducing this problem:
```python
import torch
import tvm
import tvm.testing
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

target = "llvm"


class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return (out['pred_logits'], out['pred_boxes'])


def convert_pytorch_model(model):
    torch.set_grad_enabled(False)
    model = TraceWrapper(model.eval())
    model.eval()
    inp = torch.rand(1, 3, 20, 20)

    with torch.no_grad():
        trace = torch.jit.trace(model, inp)
        mod, params = relay.frontend.from_pytorch(trace, [("input0", inp.shape)])

    with tvm.transform.PassContext(opt_level=3):
        mod["main"] = bind_params_by_name(mod["main"], params)
        seq = tvm.transform.Sequential(
            [
                transform.InferType(),
                transform.FoldConstant(),
            ]
        )
        mod = seq(mod)  # The problem is here
        lib = relay.build(mod, target=target, params=params)
    return lib


if __name__ == '__main__':
    # Pytorch model to use
    model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
    ### TVM portion ###
    lib = convert_pytorch_model(model)
    lib_name = "out.tar"
    lib.export_library(lib_name)
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] masahi commented on issue #7530: [Bug] Error in constant folding when model has dropout operator
Posted by GitBox <gi...@apache.org>.
masahi commented on issue #7530:
URL: https://github.com/apache/tvm/issues/7530#issuecomment-786398027
OK, I will take a look at this.
[GitHub] [tvm] masahi commented on issue #7530: [Bug] Error in constant folding when model has dropout operator
Posted by GitBox <gi...@apache.org>.
masahi commented on issue #7530:
URL: https://github.com/apache/tvm/issues/7530#issuecomment-786465121
We can add `SimplifyInference` to the list of prerequisite passes of `FoldConstant`: https://github.com/apache/tvm/blob/4e211a735221a9b9d188422025e2d464e37b3c96/src/relay/transforms/fold_constant.cc#L386
That is, change the `CreateFunctionPass` call to:
```cpp
return CreateFunctionPass(pass_func, 2, "FoldConstant", {"SimplifyInference"});
```
`SimplifyInference` will then run automatically before `FoldConstant`. The downside is that it would run every time `FoldConstant` runs, which is many more times than necessary.
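The mechanism described here can be modeled with a toy pass manager that, whenever it schedules a pass, first runs that pass's prerequisites. This is a hypothetical plain-Python sketch (`ToyPass`, `ToyPassManager` and the lambdas are illustrative, and only one level of prerequisites is resolved); it only makes the cost visible: with `FoldConstant` scheduled twice, `SimplifyInference` also runs twice.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class ToyPass:
    name: str
    func: Callable[[List[str]], List[str]]
    required: List[str] = field(default_factory=list)


class ToyPassManager:
    """Hypothetical sketch: a pass's prerequisites run right before it, every time."""

    def __init__(self, registry: Dict[str, ToyPass]):
        self.registry = registry
        self.log: List[str] = []

    def _apply(self, p: ToyPass, ops: List[str]) -> List[str]:
        # Record every pass application so repeated runs can be counted.
        self.log.append(p.name)
        return p.func(ops)

    def run(self, pass_names: List[str], ops: List[str]) -> List[str]:
        for name in pass_names:
            p = self.registry[name]
            for dep in p.required:
                ops = self._apply(self.registry[dep], ops)
            ops = self._apply(p, ops)
        return ops


registry = {
    "SimplifyInference": ToyPass(
        "SimplifyInference", lambda ops: [op for op in ops if op != "nn.dropout"]
    ),
    # Mirrors the proposed {"SimplifyInference"} prerequisite list.
    "FoldConstant": ToyPass("FoldConstant", lambda ops: ops, required=["SimplifyInference"]),
    "FuseOps": ToyPass("FuseOps", lambda ops: ops),
}

pm = ToyPassManager(registry)
result = pm.run(["FoldConstant", "FuseOps", "FoldConstant"], ["nn.dense", "nn.dropout"])
```

Only the first `SimplifyInference` run does any work here; the second one is pure overhead, which is the concern raised above.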
[GitHub] [tvm] echuraev commented on issue #7530: [Bug] Error in constant folding when model has dropout operator
Posted by GitBox <gi...@apache.org>.
echuraev commented on issue #7530:
URL: https://github.com/apache/tvm/issues/7530#issuecomment-787712141
@masahi Thank you. I think it is the right fix for this problem.
[GitHub] [tvm] vinx13 commented on issue #7530: [Bug] Error in constant folding when model has dropout operator
Posted by GitBox <gi...@apache.org>.
vinx13 commented on issue #7530:
URL: https://github.com/apache/tvm/issues/7530#issuecomment-786920775
Alternatively, we can ensure that every op has `TOpPattern` registered.
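This suggestion amounts to an invariant that could be checked mechanically, for example in a unit test over the registered ops. In TVM itself a pattern is typically attached with `tvm.relay.op.register_pattern`; the plain-Python sketch below avoids depending on TVM, and its registry contents and helper names (`find_missing`, `with_default_patterns`) are hypothetical. It lists the ops a strict lookup would reject and fills them with the conservative opaque pattern.

```python
# Value of kOpaque in TVM's OpPatternKind enum.
OPAQUE = 8


def find_missing(op_names, patterns):
    # Ops that would trip a strict TOpPattern lookup.
    return sorted(op for op in op_names if op not in patterns)


def with_default_patterns(op_names, patterns, default=OPAQUE):
    # Return a copy of the registry in which every listed op has a pattern.
    completed = dict(patterns)
    for op in find_missing(op_names, patterns):
        completed[op] = default
    return completed


# Hypothetical registry contents; 4 and 1 stand for kOutEWiseFusable and kBroadcast.
all_ops = ["nn.conv2d", "nn.dropout", "add"]
patterns = {"nn.conv2d": 4, "add": 1}
completed = with_default_patterns(all_ops, patterns)
```

A test that asserts `find_missing(...)` is empty would catch a newly added op that forgot to register a pattern, instead of failing later inside a pass.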
[GitHub] [tvm] masahi closed issue #7530: [Bug] Error in constant folding when model has dropout operator
Posted by GitBox <gi...@apache.org>.
masahi closed issue #7530:
URL: https://github.com/apache/tvm/issues/7530
[GitHub] [tvm] masahi commented on issue #7530: [Bug] Error in constant folding when model has dropout operator
Posted by GitBox <gi...@apache.org>.
masahi commented on issue #7530:
URL: https://github.com/apache/tvm/issues/7530#issuecomment-787633482
@echuraev I confirmed that https://github.com/apache/tvm/issues/7530 fixes this. Note, however, that this solution leaves `dropout` in the module after `FoldConstant`; if that is not desired, you need to run `SimplifyInference`.