Skip to main content

A tool to capture local variables from any function, especially useful for visualizing attention maps in deep learning models

Project description

AnyCapture

AnyCapture License: Apache 2.0 Downloads Python Version

AnyCapture是一个Python工具库,专门用于捕获函数执行过程中的局部变量。该库主要致力于解决深度学习模型中间结果提取的技术难题,特别适用于深度学习模型中Attention Map的可视化分析。

✨ 核心特性

  • 🚀 多变量捕获:支持通过装饰器同时捕获多个局部变量
  • 📦 字典缓存:变量以结构化字典形式存储,便于管理和访问
  • 🧹 缓存管理:提供clear()方法进行缓存清理
  • 🔄 队列功能:支持限制缓存大小,自动管理内存使用

背景与动机

在深度学习模型可视化过程中,开发者经常遇到以下技术挑战:

传统解决方案的局限性:

  • 返回值传递法:需要修改模型结构,将嵌套在模型深处的Attention Map逐层返回,在训练时又需要还原代码
  • 全局变量法:使用全局变量直接记录Attention Map,容易在训练时遗忘修改导致内存溢出

这些问题在实际开发中普遍存在,严重影响了开发效率。

PyTorch Hook机制的技术限制:

虽然PyTorch提供了hook机制来获取中间结果:

handle = net.conv2.register_forward_hook(hook)

但在实际应用中存在以下技术障碍:

以Vision Transformer为例,其典型结构如下:

class VisionTransformer(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.blocks = nn.Sequential(*[Block(...) for i in range(depth)])
        ...

每个Block中包含Attention模块:

class Block(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.attn = Attention(...)
        ...

Hook机制的技术挑战:

  1. 模块路径复杂:深度嵌套的模块结构导致准确定位目标模块困难
  2. 批量注册繁琐:Transformer中每层都包含attention map,逐个注册hook效率低下

AnyCapture的技术优势:

基于上述技术分析,AnyCapture提供了一种更为简洁高效的解决方案,具备以下核心特性:

  • 🎯 精准定位:支持按变量名精确捕获模型中间结果
  • 多变量支持:装饰器支持同时捕获多个目标变量
  • 🚀 高效便捷:可批量获取Transformer模型中所有层的attention map
  • 🔄 非侵入式设计:无需修改现有函数代码
  • 🎯 开发友好:可视化分析完成后无需修改训练代码

安装指南

使用pip安装AnyCapture:

pip install AnyCapture

使用指南

安装完成后,通过get_local装饰器可以便捷地捕获函数内部的局部变量。

基础用法:单变量捕获

以捕获attention_map变量为例:

步骤1:在模型文件中添加装饰器

from anycapture import get_local

@get_local('attention_map')
def your_attention_function(*args, **kwargs):
    ...
    attention_map = ... 
    ...
    return ...

步骤2:在分析代码中激活装饰器并获取结果

from anycapture import get_local

get_local.activate()  # 激活装饰器
from ... import model  # 注意:模型导入必须在装饰器激活之后

# 加载模型和数据
...
output = model(data)

# 获取捕获的变量
cache = get_local.cache  # 输出格式:{'your_attention_function.attention_map': [attention_map]}

捕获结果以字典形式存储在get_local.cache中,键值格式为函数名.变量名,对应值为变量值列表。

基本功能

# 查看缓存内容
print(get_local.cache)

# 清空缓存
get_local.clear()

# 激活/取消激活
get_local.activate()      # 激活捕获
get_local.deactivate()    # 取消激活,提高性能

# 队列功能:限制缓存大小
get_local.activate(max_size=10)  # 只保留最近10次结果
get_local.set_size(5)  # 动态调整为5个元素

详细文档请参考:DOC.md | demo.ipynb | 更新日志

可视化案例

以下展示了使用AnyCapture对Vision Transformer小型模型(vit_small)进行可视化分析的部分结果。完整案例请参考 demo.ipynb

由于标准Vision Transformer的所有Attention Map均在Attention.forward方法中计算,仅需对该方法添加装饰器,即可批量提取模型12层Transformer的全部Attention Map数据。

单个Attention Head可视化结果:

a head

单层全部Attention Heads可视化结果:

heads

网格级别Attention Map可视化:

grid2grid

版权信息

原始作者: luo3300612
原始项目: Visualizer
当前维护者: zzaiyan

本项目基于luo3300612的Visualizer项目进行重构和功能扩展。为避免与PyPI现有软件包的命名冲突,项目重命名为AnyCapture。特此对原作者的卓越贡献表示诚挚感谢。

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

anycapture-0.1.3.tar.gz (286.4 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

anycapture-0.1.3-py3-none-any.whl (374.1 kB view details)

Uploaded Python 3

File details

Details for the file anycapture-0.1.3.tar.gz.

File metadata

  • Download URL: anycapture-0.1.3.tar.gz
  • Upload date:
  • Size: 286.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for anycapture-0.1.3.tar.gz
Algorithm Hash digest
SHA256 a2a02e621aeb7f62f4638dcf1b6c2a5517d510e5650347cb59b35d114d9623b1
MD5 4aca07a7969c014fe8e87dc27fa91d02
BLAKE2b-256 837253262b31433a92285685d70d33f1531c7ae03dc973d9823470befc3bc77c

See more details on using hashes here.

File details

Details for the file anycapture-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: anycapture-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 374.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for anycapture-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 1755f2b90f60f3d35775e9cb153911b045b3d165606e2de5a3c59c864d81971a
MD5 a5bdba80035a4ce5243ed2086c35d880
BLAKE2b-256 2ace2a403a478156c97fb0d2e3bffb606946cf6e8afb6f1bad51c294059f464c

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page