Build Transformer models from configuration files


OSC-Transformers

🚀 Configuration-driven Modular Transformer Model Building Framework


Flexible, efficient, and extensible tools for building Transformer models

Chinese documentation: README-zh.md

✨ Features

  • 🔧 Configuration Driven: Build Transformer models from simple configuration files
  • 🧩 Modular Design: Register custom implementations for every component type (see the sketch after this list)
  • ⚡ High Performance: CUDA Graph and Paged Attention support
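
The registration hook itself is not documented in this README, so the following is only a hypothetical sketch of what plugging in a custom component might look like. It assumes the package exposes a registry whose categories mirror the @normalization / @attention / @feedforward keys used in the config files; the registry object and register decorator are assumptions, not confirmed API:

import torch
import torch.nn as nn
from osc_transformers import registry  # hypothetical: assumes a public registry object

@registry.normalization.register("LayerNorm")  # hypothetical decorator; "normalization" mirrors @normalization in the cfg
class LayerNorm(nn.Module):
    """Plain LayerNorm as a drop-in alternative to the built-in RMSNorm."""

    def __init__(self, in_dim: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(in_dim, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)

If registration works this way, the component would then be selectable from any config section with @normalization = "LayerNorm".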

🛠️ Supported Components

Component Type        Built-in Implementation
Attention Mechanism   PagedAttention
Feedforward Network   SwiGLU
Normalization         RMSNorm
Embedding Layer       VocabEmbedding
Output Head           LMHead

📦 Installation

  • Install the latest version of PyTorch
  • Install flash-attn: downloading an official pre-built wheel is recommended to avoid compilation issues
  • Install osc-transformers:

pip install osc-transformers
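
Before building a model, it can help to confirm that the three prerequisites actually load together. A minimal sanity check, assuming a CUDA-capable machine:

import torch
import flash_attn  # import fails here if the wheel does not match your PyTorch/CUDA build

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("flash-attn:", flash_attn.__version__)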

🚀 Quick Start

Create model.cfg (this example reproduces Qwen3-0.6B):

[model]
@architecture = "TransformerDecoder"
num_layers = 28
prenorm = "True"

[model.attention]
@attention = "PagedAttention"
in_dim = 1024
num_heads = 16
head_dim = 128
num_query_groups = 8
rope_base = 1000000
q_bias = "False"
k_bias = "False"
v_bias = "False"
o_bias = "False"

[model.attention.k_norm]
@normalization = "RMSNorm"
in_dim = 128
eps = 0.000001

[model.attention.q_norm]
@normalization = "RMSNorm"
in_dim = 128
eps = 0.000001

[model.embedding]
@embedding = "VocabEmbedding"
num_embeddings = 151936
embedding_dim = 1024

[model.feedforward]
@feedforward = "SwiGLU"
in_dim = 1024
hidden_dim = 3072
up_bias = "False"
gate_bias = "False"
down_bias = "False"

[model.head]
@head = "LMHead"
in_dim = 1024
out_dim = 151936
bias = "False"

[model.norm]
@normalization = "RMSNorm"
in_dim = 1024
eps = 0.000001
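
For reference, the numbers above follow Qwen3-0.6B's grouped-query attention layout. A quick check of the geometry implied by the [model.attention] section (illustrative arithmetic only, not part of the library):

# Shapes implied by the [model.attention] section above
in_dim, num_heads, head_dim, num_query_groups = 1024, 16, 128, 8

q_width = num_heads * head_dim            # 2048: query projection output width
kv_width = num_query_groups * head_dim    # 1024: key/value width shared across groups (GQA)
assert num_heads % num_query_groups == 0  # 2 query heads attend to each KV group

# q_norm / k_norm operate per head, which is why their sections use in_dim = 128 (= head_dim)
assert head_dim == 128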

Code example:

from osc_transformers import TransformerDecoder, Sequence, SamplingParams

# Build model
model = TransformerDecoder.from_config("model.cfg")
model.setup(gpu_memory_utilization=0.9, max_model_len=40960, device="cuda:0")

# Batch inference
seqs = [
    Sequence(
        token_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        sampling_params=SamplingParams(temperature=0.5, max_generate_tokens=1024),
    )
]
seqs = model.batch(seqs)

# Streaming inference
seq = Sequence(
    token_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    sampling_params=SamplingParams(temperature=0.5, max_generate_tokens=1024),
)
for token in model.stream(seq):
    pass  # consume each generated token as it arrives
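
The token ids above are placeholders. In practice you would obtain them from a tokenizer; the following is a hypothetical pairing with the Hugging Face tokenizer for Qwen3-0.6B (transformers is not a dependency of osc-transformers, and this assumes stream yields token ids):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # assumption: the config above matches this checkpoint

seq = Sequence(
    token_ids=tok.encode("Hello, world"),
    sampling_params=SamplingParams(temperature=0.5, max_generate_tokens=64),
)
for token in model.stream(seq):
    print(tok.decode([token]), end="", flush=True)  # assumes each yielded item is a single token id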

📚 Inference Performance

osc-transformers bench examples/decoder.cfg --num_seqs 64 --max_input_len 1024 --max_output_len 1024 --gpu_memory_utilization 0.9
Device  Throughput (tokens/s)
4090    5200
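
To put the figure in context, illustrative arithmetic from the bench settings above (not a measurement):

num_seqs, max_output_len, throughput = 64, 1024, 5200
total_output_tokens = num_seqs * max_output_len      # 65,536 tokens if every sequence generates in full
print(f"~{total_output_tokens / throughput:.1f} s")  # ≈ 12.6 s of decode at 5,200 tokens/s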

📚 Acknowledgments

The core code of this project is largely adapted from the following projects:

🤝 Contributing

Issues and pull requests are welcome!

📄 License

MIT License
