Skip to main content

A tool to capture local variables from any function, especially useful for visualizing attention maps in deep learning models

Project description

AnyCapture

AnyCapture License: Apache 2.0 Downloads Python Version

AnyCapture是一个Python工具库,专门用于捕获函数执行过程中的局部变量。该库主要致力于解决深度学习模型中间结果提取的技术难题,特别适用于深度学习模型中Attention Map的可视化分析。

✨ 核心特性

  • 🚀 多变量捕获:支持通过装饰器同时捕获多个局部变量
  • 📦 字典缓存:变量以结构化字典形式存储,便于管理和访问
  • 🧹 缓存管理:提供clear()方法进行缓存清理
  • 🔄 队列功能:支持限制缓存大小,自动管理内存使用
  • 🛡️ dtype兼容:遇到bfloat16等NumPy不支持类型时,自动转换为fp32
  • 🧩 自定义整理:支持通过set_collate_fn注入自定义缓存整理逻辑

背景与动机

在深度学习模型可视化过程中,开发者经常遇到以下技术挑战:

传统解决方案的局限性:

  • 返回值传递法:需要修改模型结构,将嵌套在模型深处的Attention Map逐层返回,在训练时又需要还原代码
  • 全局变量法:使用全局变量直接记录Attention Map,容易在训练时遗忘修改导致内存溢出

这些问题在实际开发中普遍存在,严重影响了开发效率。

PyTorch Hook机制的技术限制:

虽然PyTorch提供了hook机制来获取中间结果:

handle = net.conv2.register_forward_hook(hook)

但在实际应用中存在以下技术障碍:

以Vision Transformer为例,其典型结构如下:

class VisionTransformer(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.blocks = nn.Sequential(*[Block(...) for i in range(depth)])
        ...

每个Block中包含Attention模块:

class Block(nn.Module):
    def __init__(self, *args, **kwargs):
        ...
        self.attn = Attention(...)
        ...

Hook机制的技术挑战:

  1. 模块路径复杂:深度嵌套的模块结构导致准确定位目标模块困难
  2. 批量注册繁琐:Transformer中每层都包含attention map,逐个注册hook效率低下

AnyCapture的技术优势:

基于上述技术分析,AnyCapture提供了一种更为简洁高效的解决方案,具备以下核心特性:

  • 🎯 精准定位:支持按变量名精确捕获模型中间结果
  • 多变量支持:装饰器支持同时捕获多个目标变量
  • 🚀 高效便捷:可批量获取Transformer模型中所有层的attention map
  • 🔄 非侵入式设计:无需修改现有函数代码
  • 🎯 开发友好:可视化分析完成后无需修改训练代码

安装指南

使用pip安装AnyCapture:

pip install AnyCapture

使用指南

安装完成后,通过get_local装饰器可以便捷地捕获函数内部的局部变量。

基础用法:单变量捕获

以捕获attention_map变量为例:

步骤1:在模型文件中添加装饰器

from anycapture import get_local

@get_local('attention_map')
def your_attention_function(*args, **kwargs):
    ...
    attention_map = ... 
    ...
    return ...

步骤2:在分析代码中激活装饰器并获取结果

from anycapture import get_local

get_local.activate()  # 激活装饰器
from ... import model  # 注意:模型导入必须在装饰器激活之后

# 加载模型和数据
...
output = model(data)

# 获取捕获的变量
cache = get_local.cache  # 输出格式:{'your_attention_function.attention_map': [attention_map]}

捕获结果以字典形式存储在get_local.cache中,键值格式为函数名.变量名,对应值为变量值列表。

基本功能

# 查看缓存内容
print(get_local.cache)

# 清空缓存
get_local.clear()

# 激活/取消激活
get_local.activate()      # 激活捕获
get_local.deactivate()    # 取消激活,提高性能

# 队列功能:限制缓存大小
get_local.activate(max_size=10)  # 只保留最近10次结果
get_local.set_size(5)  # 动态调整为5个元素

数据类型兼容与自定义collate_fn

AnyCapture默认会在缓存前对Tensor执行detach().cpu();如果遇到NumPy不支持的浮点类型(如bfloat16),会显式调用.float()转换为fp32后再写入缓存。

当你需要更细粒度控制(例如压缩缓存、只保留统计量、转换为特定结构)时,可以优雅地注入自定义collate_fn

from anycapture import get_local

# 自定义:将Tensor转为fp32并仅保留均值,降低缓存体积
def my_collate(value):
    if hasattr(value, 'detach'):
        tensor = value.detach().cpu().float()
        return tensor.mean().item()
    return value

get_local.set_collate_fn(my_collate)

# 恢复默认行为(内置bfloat16兼容逻辑)
get_local.set_collate_fn(None)

详细文档请参考:DOC.md | demo.ipynb | 更新日志

可视化案例

以下展示了使用AnyCapture对Vision Transformer小型模型(vit_small)进行可视化分析的部分结果。完整案例请参考 demo.ipynb

由于标准Vision Transformer的所有Attention Map均在Attention.forward方法中计算,仅需对该方法添加装饰器,即可批量提取模型12层Transformer的全部Attention Map数据。

单个Attention Head可视化结果:

a head

单层全部Attention Heads可视化结果:

heads

网格级别Attention Map可视化:

grid2grid

版权信息

原始作者: luo3300612
原始项目: Visualizer
当前维护者: zzaiyan

本项目基于luo3300612的Visualizer项目进行重构和功能扩展。为避免与PyPI现有软件包的命名冲突,项目重命名为AnyCapture。特此对原作者的卓越贡献表示诚挚感谢。

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

anycapture-0.1.4.tar.gz (287.4 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

anycapture-0.1.4-py3-none-any.whl (375.1 kB view details)

Uploaded Python 3

File details

Details for the file anycapture-0.1.4.tar.gz.

File metadata

  • Download URL: anycapture-0.1.4.tar.gz
  • Upload date:
  • Size: 287.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.20

File hashes

Hashes for anycapture-0.1.4.tar.gz
Algorithm Hash digest
SHA256 bfb12d71e94c3217c0eb33e62f65a53222aae24119eaacd8511fc34f451553c7
MD5 2bc3ec873d7d9af4f83b918d3b3f1df0
BLAKE2b-256 bff889195e5e1d33dbc18e30b44141ddba81c7cb266ea8265f4fb54e8336cab3

See more details on using hashes here.

File details

Details for the file anycapture-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: anycapture-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 375.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.20

File hashes

Hashes for anycapture-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 eddebeb4fb0761506331113d24cc23448ff17ba4ece979453d9cec9c4afe5c6b
MD5 f7b5088a016f78835fbe5187e9ae0985
BLAKE2b-256 07fca440891005c441377c6ff6d1282efd3e3341dd99ab82d165a282933e2354

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page