Build Transformer models via configuration files

OSC-Transformers

🚀 Configuration-driven Modular Transformer Model Building Framework

Flexible, efficient, and extensible Transformer model building tools

Chinese documentation: README-zh.md

✨ Features

  • 🔧 Configuration Driven: Build Transformer models from simple configuration files
  • 🧩 Modular Design: Register custom implementations for every component type (see the registration sketch below)
  • ⚡ High Performance: CUDA Graph and Paged Attention support
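
Component registration is configuration-driven as well: the "@attention", "@feedforward", and similar markers in the config select a factory from a registry. Below is a minimal sketch of registering a custom feedforward module, assuming a catalogue-style registry exposed as osc_transformers.registry with per-component-type decorators; the exact entry point is an assumption, not confirmed by this page, so check the project docs for the real API.

import torch.nn as nn

from osc_transformers import registry  # assumed registry object, not confirmed by this page

@registry.feedforward("MyGELUFFN")  # hypothetical decorator matching the @feedforward config marker
class MyGELUFFN(nn.Module):
    """Plain GELU MLP as a stand-in for the built-in SwiGLU."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.up = nn.Linear(in_dim, hidden_dim)
        self.down = nn.Linear(hidden_dim, in_dim)
        self.act = nn.GELU()

    def forward(self, x):
        return self.down(self.act(self.up(x)))

Once registered, a config would select it with @feedforward = "MyGELUFFN" under [model.feedforward].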

🛠️ Supported Components

Component Type       Built-in Implementation
Attention Mechanism  PagedAttention
Feedforward Network  SwiGLU
Normalization        RMSNorm
Embedding Layer      VocabEmbedding
Output Head          LMHead

📦 Installation

  • Install the latest version of PyTorch
  • Install flash-attn: downloading an official pre-built wheel is recommended, to avoid compilation issues
  • Install osc-transformers:
pip install osc-transformers

🚀 Quick Start

Create model.cfg (Qwen3-0.6B):

[model]
@architecture = "TransformerDecoder"
num_layers = 28
prenorm = "True"

[model.attention]
@attention = "PagedAttention"
in_dim = 1024
num_heads = 16
head_dim = 128
num_query_groups = 8
rope_base = 1000000
q_bias = "False"
k_bias = "False"
v_bias = "False"
o_bias = "False"

[model.attention.k_norm]
@normalization = "RMSNorm"
in_dim = 128
eps = 0.000001

[model.attention.q_norm]
@normalization = "RMSNorm"
in_dim = 128
eps = 0.000001

[model.embedding]
@embedding = "VocabEmbedding"
num_embeddings = 151936
embedding_dim = 1024

[model.feedforward]
@feedforward = "SwiGLU"
in_dim = 1024
hidden_dim = 3072
up_bias = "False"
gate_bias = "False"
down_bias = "False"

[model.head]
@head = "LMHead"
in_dim = 1024
out_dim = 151936
bias = "False"

[model.norm]
@normalization = "RMSNorm"
in_dim = 1024
eps = 0.000001

Code example:

from osc_transformers import TransformerDecoder, Sequence, SamplingParams

# Build model
model = TransformerDecoder.from_config("model.cfg")
model.setup(gpu_memory_utilization=0.9, max_model_len=40960, device="cuda:0")

# Batch inference
seqs = [
    Sequence(
        token_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        sampling_params=SamplingParams(temperature=0.5, max_generate_tokens=1024),
    )
]
seqs = model.batch(seqs)

# Streaming inference
seq = Sequence(
    token_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    sampling_params=SamplingParams(temperature=0.5, max_generate_tokens=1024),
)
for token in model.stream(seq):
    pass  # handle each generated token here
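
In practice the prompt token ids come from a tokenizer. The following is a hedged end-to-end sketch, assuming the Hugging Face tokenizer for Qwen/Qwen3-0.6B and that model.stream yields generated token ids one at a time; neither detail is confirmed by this page.

from transformers import AutoTokenizer

from osc_transformers import Sequence, SamplingParams

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
prompt_ids = tokenizer.encode("Hello, world!")

seq = Sequence(
    token_ids=prompt_ids,
    sampling_params=SamplingParams(temperature=0.5, max_generate_tokens=128),
)
generated = list(model.stream(seq))  # model built via from_config as above; assumes token ids are yielded
print(tokenizer.decode(generated))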

📚 Inference Performance

osc-transformers bench examples/configs/qwen3-0_6B.cfg --num_seqs 64 --max_input_len 1024 --max_output_len 1024 --gpu_memory_utilization 0.9

Architecture        Model       Device  Throughput (tokens/s)
TransformerDecoder  Qwen3-0.6B  4090    5400
TransformerDecoder  Qwen3-0.6B  3090    4000
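
As a rough sanity check on these numbers, and assuming throughput counts generated tokens only (this page does not state how the bench computes it), the command above generates 64 × 1024 = 65,536 tokens:

num_seqs, max_output_len = 64, 1024
total_tokens = num_seqs * max_output_len  # 65,536 generated tokens
print(total_tokens / 5400)                # ≈ 12.1 s end to end on the 4090
print(total_tokens / 4000)                # ≈ 16.4 s end to end on the 3090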

📚 Acknowledgments

The core code of this project mainly comes from the following projects:

🤝 Contributing

Issues and pull requests are welcome!

📄 License

MIT License

Download files

Download the file for your platform.

Source Distribution

osc_transformers-0.3.0.tar.gz (28.7 kB)

Built Distribution

osc_transformers-0.3.0-py3-none-any.whl (35.7 kB)

File details

Details for the file osc_transformers-0.3.0.tar.gz.

File metadata

  • Download URL: osc_transformers-0.3.0.tar.gz
  • Upload date:
  • Size: 28.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for osc_transformers-0.3.0.tar.gz

Algorithm    Hash digest
SHA256       87db3e41a21506cb6d2c6e821eac6e8d9ff5596c72d036b5a36bf59e57bc203c
MD5          5ba7f96f625fe40b34aac55438263341
BLAKE2b-256  9fae3fb4e0389dee0edfcfc87c39ac4d77df407790b6d26eb78564221d673494

File details

Details for the file osc_transformers-0.3.0-py3-none-any.whl.

File hashes

Hashes for osc_transformers-0.3.0-py3-none-any.whl

Algorithm    Hash digest
SHA256       a1e6a928d8335f00878f8606c551b787b62192a2e4e02557fc4458e0b8ab9fc3
MD5          7e86c87ae4fa61908d81bea52ad36ebf
BLAKE2b-256  71bcac2fea03365462e45291624e1a9af3e9e41209457800a2f25063f56cf311
