Posted to discuss-archive@tvm.apache.org by yuheng huang via Apache TVM Discuss <no...@discuss.tvm.ai> on 2021/11/01 23:44:16 UTC

[Apache TVM Discuss] [Questions] [Pytorch] register_forward_hook support


PyTorch does support forward hooks for `torch.jit.trace(...)`.

For details, you can check: https://github.com/pytorch/pytorch/issues/34329 and https://github.com/pytorch/pytorch/pull/49544 . 

For usage examples, see this test file from PyTorch: https://github.com/pytorch/pytorch/blob/5c23888953d277041b341d38dcd5b2d891619ba4/test/jit/test_hooks.py . 

I think a hook mechanism would be useful: being able to get intermediate outputs is convenient for debugging (and for cases like checking quantization accuracy, as you mentioned). PyTorch itself supports this feature; however, it seems that we can't do the same thing in TVM for now. Let me explain a little:

To actually get the intermediate result, one way is to simply "print" the intermediate tensor inside the hook. You can use `torch.jit.trace` to compile a PyTorch model that has a print call inside a hook. However, importing that model into TVM then fails with an error saying that some operators are not implemented:
```
The following operators are not implemented: ['prim::Print']
```
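As far as I can tell, this message comes from a check in TVM's PyTorch frontend (`tvm.relay.frontend.from_pytorch`): before converting, it collects the operator kinds in the TorchScript graph and reports every kind it has no converter for. A print call in a hook shows up as a `prim::Print` node, which has no converter. The plain-Python sketch below only mimics that check; `KNOWN_OPS`, `find_missing_ops` and `check_graph` are hypothetical stand-ins, not TVM's actual converter table or API.

```python
# Hypothetical subset of operator kinds a frontend knows how to convert.
KNOWN_OPS = {"prim::Constant", "prim::GetAttr", "aten::_convolution", "aten::relu"}


def find_missing_ops(op_kinds):
    """Return the op kinds (first-seen order, no duplicates) with no converter.

    In-place variants such as "aten::relu_" are accepted when their
    out-of-place form is known.
    """
    missing = []
    for kind in op_kinds:
        known = kind in KNOWN_OPS or (kind.endswith("_") and kind[:-1] in KNOWN_OPS)
        if not known and kind not in missing:
            missing.append(kind)
    return missing


def check_graph(op_kinds):
    # Refuse to convert if any operator kind lacks a converter.
    missing = find_missing_ops(op_kinds)
    if missing:
        raise NotImplementedError(
            "The following operators are not implemented: {}".format(missing)
        )


# Hypothetical op kinds of a graph whose hook calls print().
graph_ops = ["prim::GetAttr", "aten::_convolution", "aten::relu_", "prim::Print"]
try:
    check_graph(graph_ops)
except NotImplementedError as err:
    print(err)  # The following operators are not implemented: ['prim::Print']
```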

Another way is to collect the outputs in a Python class, like this:

```
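from typing import Tuple

import torch


class HookRecorder:
    def __init__(self):
        self.recorder = dict()  # layer name -> list of captured output tensors
        self.handlers = list()  # hook handles, kept so the hooks can be removed

    def _register_hooker(self, name):
        self.recorder[name] = list()
        def named_hooker(module, input: Tuple[torch.Tensor], output: torch.Tensor):
            # Runs after module.forward(); store that module's output
            self.recorder[name].append(output)
        return named_hooker

    def register_hookers(self, target_sub_modules, layer_names):
        for module, layer_name in zip(target_sub_modules, layer_names):
            handler = module.register_forward_hook(self._register_hooker(layer_name))
            self.handlers.append(handler)  # inside the loop so every handle is kept

    def remove_handlers(self):
        for handler in self.handlers:
            handler.remove()
        self.handlers.clear()

    def __del__(self):
        self.remove_handlers()


# Small example network (hypothetical); replace with your own model and input
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, 3)
        self.conv2 = torch.nn.Conv2d(8, 8, 3)

    def forward(self, x):
        return self.conv2(torch.relu(self.conv1(x)))


net = Net()
x = torch.randn(1, 3, 32, 32)

hook = HookRecorder()
hook.register_hookers([net.conv2], ["conv2"])
out = net(x)
print(hook.recorder)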
```

This way, we can indeed get intermediate values from the Python object. However, this cannot be compiled by `torch.jit.trace`.
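One way to see why the recorder does not carry over: tracing records the operations executed while the model runs once on an example input, and the traced program only replays those operations. A Python side effect such as appending to `self.recorder` runs during that one tracing call but is not part of what gets replayed. The toy record-and-replay tracer below illustrates the idea in plain Python; `ToyTracer`, `model` and the lambdas standing in for layers are hypothetical and far simpler than `torch.jit.trace`.

```python
class ToyTracer:
    """Toy tracer: records ops while running them, replays them later."""

    def __init__(self):
        self.ops = []  # recorded (name, fn) pairs, replayed in order

    def op(self, name, fn, value):
        # Run the op now and remember it so replay() can re-run it.
        self.ops.append((name, fn))
        return fn(value)

    def replay(self, value):
        # Only recorded ops are re-executed; other trace-time code is not.
        for _, fn in self.ops:
            value = fn(value)
        return value


recorder = {"conv2": []}


def model(tracer, x):
    h = tracer.op("conv1", lambda v: v * 2, x)   # stands in for a layer
    out = tracer.op("conv2", lambda v: v + 1, h)
    # Stands in for a forward hook: a Python side effect, not a recorded op.
    recorder["conv2"].append(out)
    return out


tracer = ToyTracer()
print(model(tracer, 3))   # 7 (the "hook" stores 7 during tracing)
print(tracer.replay(5))   # 11, but the "hook" does not run again
print(recorder)           # {'conv2': [7]}
```

In the same way, a TVM module built from the traced graph would only contain the recorded computation, not the recording logic.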





---
[Visit Topic](https://discuss.tvm.apache.org/t/pytorch-register-forward-hook-support/11036/7) to respond.
