Skip to main content

A tool to capture local variables from any function, especially useful for visualizing attention maps in deep learning models

Project description

AnyCapture

AnyCapture是一个能够捕获任意函数中局部变量的小工具,主要功能是帮助取出嵌套在模型深处的中间结果,特别适用于深度学习模型中Attention Map的可视化

为什么需要AnyCapture?

为了可视化Attention Map,你是否有以下苦恼

  • Return大法好:通过return将嵌套在模型深处的Attention Map一层层地返回回来,然后训练模型的时候又不得不还原
  • 全局大法好:使用全局变量在Attention函数中直接记录Attention Map,结果训练的时候忘改回来导致OOM

不管你有没有,反正我有

咨询了专业人士的意见后,发现pytorch有个hook可以取出中间结果,大概查了一下,发现确实可以取出中间变量,但需要进行如下类似的hook注册

handle = net.conv2.register_forward_hook(hook)

进行这样操作的前提是我们知道要取出来的模块名,但是Transformer类模型一般是这样定义的(以Vit为例)

class VisionTransformer(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.blocks = nn.Sequential(*[Block(...) for i in range(depth)])
        ...

然后每个Block中都有一个Attention

class Block(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.attn = Attention(...)
        ...

如果要使用hooks其中的问题就是

  1. 嵌套太深,模块名不清晰,我们根本不知道我们要取的attention map怎么以model.bla.bla.bla这样一直点出来!
  2. 一般来说,Transformer中attention map每层都有一个,一个个注册实在太麻烦了

所以我就思考并查找能否通过更简洁的方法来得到Attention Map(尤其是Transformer的),而AnyCapture就是其中的一种,它具有以下特点

  • 精准直接,你可以取出任何变量名的模型中间结果
  • 快捷方便,同时取出Transformer类模型中的所有attention map
  • 非侵入式,你无须修改函数内的任何一行代码
  • 训练-测试一致,可视化完成后,你无须在训练时再将代码改回来

用法

安装

pip install bytecode
python setup.py install

安装完成后,只需要用get_local装饰一下Attention的函数,forward之后就可以拿到函数内与装饰器参数同名的局部变量啦~

Usage1

比如说,我想要函数里的attention_map变量: 在模型文件里,我们这么写

from anycapture import get_local
@get_local('attention_map')
def your_attention_function(*args, **kwargs):
    ...
    attention_map = ... 
    ...
    return ...

然后在可视化代码里,我们这么写

from anycapture import get_local
get_local.activate() # 激活装饰器
from ... import model # 被装饰的模型一定要在装饰器激活之后导入!!

# load model and data
...
out = model(data)

cache = get_local.cache # ->  {'your_attention_function': [attention_map]}

最终就会以字典形式存在get_local.cache里,其中key是你的函数名,value就是一个存储attention_map的列表

Usage2

使用Pytorch时我们往往会将模块定义成一个类,此时也是一样只要装饰类内计算出attention_map的函数即可

from anycapture import get_local

class Attention(nn.Module):
    def __init__(self):
        ...
    
    @get_local('attn_map')
    def forward(self, x):
        ...
        attn_map = ...
        ...
        return ...

其他细节请参考demo.ipynb文件

可视化结果

这里是部分可视化vit_small的结果,全部内容在demo.ipynb文件里

因为普通Vit所有Attention map都是在Attention.forward中计算出来的,所以只要简单地装饰一下这个函数,我们就可以同时取出vit中12层Transformer的所有Attention Map!

一个Head的结果

a head

一层所有Heads的结果

heads

某个grid的Attention Map

grid2grid

注意

  • 想要可视化的变量在函数内部不能被后续的同名变量覆盖了,因为get_local取的是对应名称变量在函数中的最终值
  • 进行可视化时,get_local.activate()一定要在导入模型完成,因为python装饰器是在导入时执行的
  • 训练时你不需要修改/删除任何代码,即不用删掉装饰函数的代码,因为在get_local.activate()没有执行的情况下,attention函数不会被装饰,故没有任何性能损失(同上一点,因为python装饰器是在导入时执行的)

其他

当然,其实get_local本身可以取出任何一个函数中某个局部变量的最终值,所以它应该还有其他更有趣的用途

版权声明

原作者: luo3300612
原项目: Visualizer
当前维护: zzaiyan

本项目基于原作者luo3300612的Visualizer项目进行重构和重命名,为了避免与PyPI上现有的库名称冲突,将项目重命名为AnyCapture。感谢原作者的优秀工作!

references

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

AnyCapture-0.0.2.tar.gz (288.4 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

AnyCapture-0.0.2-py3-none-any.whl (373.1 kB view details)

Uploaded Python 3

File details

Details for the file AnyCapture-0.0.2.tar.gz.

File metadata

  • Download URL: AnyCapture-0.0.2.tar.gz
  • Upload date:
  • Size: 288.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for AnyCapture-0.0.2.tar.gz
Algorithm Hash digest
SHA256 cfb8599ec597779e6f8aec8def5c31d79e4745cac03e005b3266f5dbe1a9213d
MD5 4a2b9bd3b70ae57cbd29b23b9a67e3ac
BLAKE2b-256 1c449dc8bb3c7eb10d4bf287f4850dd629bd49ba1711fed98c31a83773b377fe

See more details on using hashes here.

File details

Details for the file AnyCapture-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: AnyCapture-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 373.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for AnyCapture-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a4b61f841939c24679e8fc8eb9c9b6d1e9212bcae1283e0d729e8e1be730f851
MD5 39664158e3fb95a2ffbf4dcd773f72cd
BLAKE2b-256 39c454b1725ca0c595ac9e248bfabbf5ecec996e79006c01e442bcad99b27730

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page