以配置文件的方式构建transformer模型
Project description
OSC-Transformers
✨ 主要特性
- 🔧 配置驱动: 通过简单的配置文件构建复杂的 Transformer 模型
- 🧩 模块化设计: 支持自定义注册各种组件(注意力机制、前馈网络、归一化等)
- ⚡ 高性能优化:
- 支持 CUDA Graph 加速
- 内置 Paged Attention 机制
- 高效的内存管理
- 🎯 易于使用: 提供多种构建方式,从简单的 API 到复杂的配置文件
- 🔄 高度可扩展: 基于注册机制,轻松扩展新的模型组件
🛠️ 支持的组件
| 组件类型 | 内置实现 | 描述 |
|---|---|---|
| 注意力机制 | PagedAttention |
高效的分页注意力实现 |
| 前馈网络 | SwiGLU |
SwiGLU 激活函数的前馈网络 |
| 归一化 | RMSNorm |
RMS 归一化层 |
| 嵌入层 | VocabEmbedding |
词汇表嵌入层 |
| 输出头 | LMHead |
语言模型输出头 |
| 采样器 | SimpleSampler |
简单的采样实现 |
📦 安装
环境要求
- Python >= 3.10
- PyTorch >= 2.8.0
安装方式
pip install osc-transformers --upgrade
或从源码安装:
git clone https://github.com/your-repo/osc-transformers.git
cd osc-transformers
pip install -e .
🚀 快速开始
方式一:使用 Builder 模式
from osc_transformers import TransformerDecoderBuilder
# 创建构建器
builder = TransformerDecoderBuilder(num_layers=8, max_length=1024)
# 配置各个组件
embedding_config = '''
[embedding]
@embedding = VocabEmbedding
num_embeddings = 32000
embedding_dim = 1024
'''
builder.set_embedding(config=embedding_config)
attention_config = '''
[attention]
@attention = PagedAttention
in_dim = 1024
num_heads = 16
'''
builder.set_attention(config=attention_config)
# ... 配置其他组件
# 构建模型
model = builder.build()
方式二:使用配置文件
创建配置文件 model.cfg:
[model]
@architecture = "TransformerDecoder"
num_layers = 28
max_length = 40960
prenorm = "True"
[model.attention]
@attention = "PagedAttention"
in_dim = 1024
num_heads = 16
head_dim = 128
num_query_groups = 8
[model.embedding]
@embedding = "VocabEmbedding"
num_embeddings = 32000
embedding_dim = 1024
[model.feedforward]
@feedforward = "SwiGLU"
in_dim = 1024
hidden_dim = 3072
[model.head]
@head = "LMHead"
in_dim = 1024
out_dim = 32000
[model.norm]
@normalization = "RMSNorm"
in_dim = 1024
eps = 1e-6
加载模型:
from osc_transformers import TransformerDecoder
model = TransformerDecoder.form_config(config="model.cfg")
🔧 自定义组件
框架支持注册自定义组件,例如自定义归一化层:
from osc_transformers.normalization import Normalization
from osc_transformers.registry import Registry
import torch
@Registry.normalization.register("LayerNorm")
class LayerNorm(Normalization):
def __init__(self, in_dim: int, eps: float = 1e-5):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(in_dim))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.layer_norm(x, (x.size(-1),), self.weight, eps=self.eps)
然后在配置中使用:
[model.norm]
@normalization = "LayerNorm"
in_dim = 1024
eps = 1e-5
📚 API 文档
TransformerDecoder
主要的 Transformer 解码器模型类。
参数
num_layers(int): 解码器层数max_length(int): 最大序列长度attention(CausalSelfAttention): 注意力机制embedding(Embedding): 嵌入层feedforward(FeedForward): 前馈网络head(Head): 输出头norm(Normalization): 归一化层prenorm(bool): 是否使用预归一化
方法
form_config(config, model_section="model", empty_init=True): 从配置文件构建模型setup(**kwargs): 设置模型(如缓存等)forward(input_ids, attn_ctx): 前向传播compute_logits(x): 计算输出 logits
TransformerDecoderBuilder
构建器模式的模型构建类。
方法
set_embedding(config, section="embedding"): 设置嵌入层set_attention(config, section="attention"): 设置注意力机制set_feedforward(config, section="feedforward"): 设置前馈网络set_head(config, section="head"): 设置输出头set_norm(config, section="normalization"): 设置归一化层build(): 构建最终模型
🎯 使用场景
- 研究原型: 快速实验不同的 Transformer 架构
- 生产部署: 高性能的推理服务
- 教学演示: 理解 Transformer 内部结构
- 模型定制: 针对特定任务的模型优化
🤝 贡献
欢迎贡献代码!请查看我们的贡献指南:
- Fork 本项目
- 创建特性分支 (
git checkout -b feature/AmazingFeature) - 提交更改 (
git commit -m 'Add some AmazingFeature') - 推送到分支 (
git push origin feature/AmazingFeature) - 打开 Pull Request
📄 许可证
本项目采用 MIT 许可证 - 详见 LICENSE 文件。
🙏 致谢
- 感谢 Confection 提供的配置系统
- 感谢 PyTorch 团队提供的深度学习框架
- 感谢所有贡献者的支持
如果这个项目对您有帮助,请给我们一个 ⭐️
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
osc_transformers-0.2.0.tar.gz
(13.3 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file osc_transformers-0.2.0.tar.gz.
File metadata
- Download URL: osc_transformers-0.2.0.tar.gz
- Upload date:
- Size: 13.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.6.4
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
d6376ef18da8074609baa3b2230c416cb581e7131bee3ce3c3424818da21b36a
|
|
| MD5 |
9419d44807b223b6fb9b0e8c7156d2f2
|
|
| BLAKE2b-256 |
9cb77583a67c6785cd00372d0ad8bb678a988524e3da4cb6279bbff45160ae20
|
File details
Details for the file osc_transformers-0.2.0-py3-none-any.whl.
File metadata
- Download URL: osc_transformers-0.2.0-py3-none-any.whl
- Upload date:
- Size: 16.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.6.4
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
38ed01050c5158fe9229dcfd21acaec6e005f953abf669479f621d7d72381368
|
|
| MD5 |
346c9f0fbb5b20e00e68ac9d16ed4afe
|
|
| BLAKE2b-256 |
f59a37176bffd53fa2329729272886a43ed0a66c7101b04b97fb163ad3c089b4
|