Pytorch Hook 函数

Pytorch中带了Hook函数,Hook的中文意思是’钩子‘,刚开始看到这个词语就有点害怕,一是不认识这个词,翻译成中文也不了解这是什么意思;二是常规调库搭积木时也没有用到过这个函数;直到读到下面文章,https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca 我对hook有了初步的理解

1. 为什么需要 hook 函数

  • 当我们的神经网络出现 bug 时,没法产生我们所期望的输出时,我们通常需要进行debug,一般的做法是在 forward 函数中写 print函数,输出某些层的输出;或者通过添加断点来进行单步调试,以观察中间层的输出。这在 pytorch 中就可以通过 hook 函数来实现。
  • 由于pytorhc的自动求导机制,即当设置参数的 requires_grad=True时,那么涉及这组参数的一系列操作将会被autograd记录用以反向求导。但是在自动求导机制中只保存叶子节点,也就是中间变量在计算完成梯度后会自动释放以节省空间
x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x * 2
z = torch.mean(y)
z.backward()
print("x.grad =", x.grad)
print("y.grad =", y.grad)
print("z.grad =", z.grad)

输出

x.grad = tensor([1., 1.])
y.grad = None
z.grad = None

因此,如果我们想知道 y 和 z 的梯度,就需要用到 hook 函数。
也就是说,hook 函数用以获取我们不方便获得的一些中间变量。

2. 什么是hook函数

  • hook 其实就是一个普通的函数或类,准确的说是一个可调用的对象,callable object. 需要什么样的功能我们可根据自己的需求自己写。总之,hook 和我们常规写的函数和类没有区别。但是 pytorch 有一个机制,我们可以把写好的函数或者类注册到某些 layer (nn.Module)上,这样子当这些 layer 在执行 forward 或者 backward时其输入或输出就会自动传到我们写好的hook函数中执行。因此,这些函数就像一个钩子一样,可以挂到某些layer上或者从这些 layer 上解挂。这就是名字叫 hook 的原因。

3. Pytorch 提供的 Hook

  • 一般来说,我们在 debug 时想知道的内容有三种
    • 某个模块的输入是什么,即 在跑 forward前模块的输入
    • 某个模块的输出是什么,即 在跑 forward后模块的输出
    • 某个模块的梯度反传后是什么,即 在跑 backward后模块的状态
  • 将这三个状态的数据与我们所期望的数据进行比较,我们就可以知道哪里出现了问题;Pytorch 就提供了这三种钩子,把这三种钩子挂到指定的layer上,这些layer的输入输出就会对应的作为参数传到hook函数中运行hook函数。下图引用自
    image.png
  • pytorch nn.Module源码中就提供了这三个属性
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
  • 同时提供了三个注册方法,也就是往上面三个dict中填值的方法
    • forward prehook (executing before the forward pass),
    • forward hook (executing after the forward pass),
    • backward hook (executing after the backward pass).

register_forward_pre_hookforward前运行,获取这一个 module 的输入

    def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        r"""Registers a forward pre-hook on the module.

        The hook will be called every time before :func:`forward` is invoked.
        It should have the following signature::

            hook(module, input) -> None or modified input

        The input contains only the positional arguments given to the module.
        Keyword arguments won't be passed to the hooks and only to the ``forward``.
        The hook can modify the input. User can either return a tuple or a
        single modified value in the hook. We will wrap the value into a tuple
        if a single value is returned(unless that value is already a tuple).

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle

register_forward_hook在forward后运行,获取这个module的input和output信息

    def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        r"""Registers a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::

            hook(module, input, output) -> None or modified output

        The input contains only the positional arguments given to the module.
        Keyword arguments won't be passed to the hooks and only to the ``forward``.
        The hook can modify the output. It can modify the input inplace but
        it will not have effect on forward since this is called after
        :func:`forward` is called.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

register_backward_hook,获取反向传播中module的grad_in, grad_out信息

    def register_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
    ) -> RemovableHandle:
        r"""Registers a backward hook on the module.

        This function is deprecated in favor of :meth:`nn.Module.register_full_backward_hook` and
        the behavior of this function will change in future versions.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        """
        if self._is_full_backward_hook is True:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")

        self._is_full_backward_hook = False

        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

4.hook 实例

这里我们通过在ResNet34的每一层插入一个钩子,来获取ResNet34每一层的输出,即这里我们使用 register_forward_hook
使用下面图片作为输入

image.png

import torch
from torchvision.models import resnet34

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = resnet34(pretrained=True)
model = model.to(device)
class SaveOutput:
    def __init__(self):
        self.outputs = []
        self.inputs = []
        
    def __call__(self, module, module_in, module_out):
        print(module)
        self.inputs.append(module_in)
        self.outputs.append(module_out)
        
    def clear(self):
        self.outputs = []
        self.inputs = []
        

save_output = SaveOutput()

hook_handles = []

for layer in model.modules():
    if isinstance(layer, torch.nn.modules.conv.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)
        
        
from PIL import Image
from torchvision import transforms as T

img = Image.open('./cat.jpeg')
transform = T.Compose([T.Resize((224,224)),
                       T.ToTensor(),
                       T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.485, 0.456, 0.406],)
                      ])
x = transform(img).unsqueeze(0).to(device)
out = model(x)

输出

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

> save_output.outputs[0].size()
torch.Size([1, 64, 112, 112])
> save_output.inputs[0][0].size()
torch.Size([1, 3, 224, 224])

可以看到模块,模块的输入输出会自动作为参数传入到我们写的SaveOutput实例中并调用该实例。
下面是每一层的输出可视化

image.png


对于 Tensor的 hook

x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x * 2
y.register_hook(print)
z = torch.mean(y)
z.backward()

输出:

tensor([0.5000, 0.5000])

hook 应用于 模型剪枝 model pruning
https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,271评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,275评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,151评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,550评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,553评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,559评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,924评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,580评论 0 257
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,826评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,578评论 2 320
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,661评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,363评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,940评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,926评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,156评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,872评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,391评论 2 342

推荐阅读更多精彩内容