Paralel Diffusion Dil Modeli — GQA + AdaLN-Zero + Self-Conditioning
Project description
omgformer v2
Paralel Diffusion Dil Modeli — aynı anda tüm tokenleri üretir.
Klasik GPT 256 token için 256 forward pass yapar.
omgformer aynı çıktıyı 8–10 forward pass ile üretir.
Temel Fikir
Adım 0: "Merhaba [MASK] [MASK] [MASK] [MASK] [MASK]"
Adım 1: "Merhaba dünya [MASK] [MASK] [MASK] [MASK]"
Adım 2: "Merhaba dünya nasıl gidiyor [MASK] [MASK]"
Adım 3: "Merhaba dünya nasıl gidiyor bugün ?"
Her adımda tüm sequence bir forward pass'ten geçer.
Model hem sola hem sağa bakar.
En yüksek güvenli tokenlar açılır.
Python for loop yok — tamamen vektörize.
v2 Yenilikleri
GQA (Grouped Query Attention)
KV head sayısını azalt, Q head sayısını koru.
Büyük modellerde KV cache belleğini 4–8× küçültür.
# 16 Q head, 4 KV head → 4× bellek tasarrufu
cfg = OMGConfig(num_heads=16, num_kv_heads=4)
AdaLN-Zero
Timestep koşullandırması her katmana scale/shift/gate üçlüsüyle enjekte edilir.
Input'a toplamak yerine — DiT mimarisinden ilham.
x_attn = norm(x) * (1 + scale) + shift
x = x + gate.tanh() * attention(x_attn)
Gate'ler sıfır başlatılır → eğitim başında blok = tam geçirgen.
Self-Conditioning
Önceki adımın soft tahminleri bir sonraki adıma beslenir.
Adım N: logits_n = model(noisy, self_cond=None)
Adım N+1: logits_n1 = model(noisy, self_cond=soft_embed(logits_n))
Eğitimde %50 ihtimalle aktif edilir.
Aynı step sayısında dramatik kalite artışı.
Absorbing Diffusion
Sadece MASK değil, %10 random token da eklenir (BERT tarzı).
Daha zengin training signal → daha hızlı öğrenme.
%80 → MASK token
%10 → rastgele başka token ← YENİ
%10 → orijinal token kalır
AMP + Gradient Accumulation
TrainingConfig(
use_amp=True,
amp_dtype="bfloat16", # 1.5-2× hız, yarı bellek
grad_accum_steps=8, # 8× büyük efektif batch
)
torch.compile
TrainingConfig(use_compile=True) # ~30% ek hız
Remasking
Inference'ta düşük güvenli tokenlar yeniden maskelenir.
Ardışık adımlarda daha iyi global tutarlılık.
decoder.generate(prompt, remask_prob=0.05)
Kurulum
pip install -e ".[train,dev]"
Hızlı Başlangıç
from omgformer import OMGConfig, OMGModel, MaskScheduler, ParallelDecoder
cfg = OMGConfig.from_preset("omgformer-base")
model = OMGModel(cfg).eval()
sched = MaskScheduler(
steps=64,
mask_token_id=cfg.mask_token_id,
vocab_size=cfg.vocab_size,
)
decoder = ParallelDecoder(model, sched)
prompt_ids = tokenizer.encode("İstanbul'un tarihi")
out = decoder.generate(
prompt_ids,
new_tokens=128,
steps=10,
temperature=0.9,
top_p=0.95,
remask_prob=0.05,
)
print(tokenizer.decode(out[0]))
Pipeline
from omgformer import pipeline
gen = pipeline("text-generation", model="omgformer-base")
result = gen("Yapay zeka", new_tokens=256, steps=10)
print(result["generated_text"])
Eğitim
from omgformer.training import Trainer, TrainingConfig
train_cfg = TrainingConfig(
steps=100_000,
batch_size=32,
lr=3e-4,
use_amp=True,
use_compile=True,
grad_accum_steps=4,
self_cond_prob=0.5,
use_wandb=True,
)
trainer = Trainer(model, sched, train_cfg, get_batch=my_data_loader)
trainer.fit()
Çok GPU (FSDP)
from omgformer.training import wrap_fsdp
model = wrap_fsdp(model, device_id=local_rank)
# torchrun --nproc_per_node=8 train.py
Model Boyutları
| Preset | Katman | Hidden | Heads (Q/KV) | ~Parametre |
|---|---|---|---|---|
| tiny | 4 | 256 | 4/4 | ~14M |
| small | 8 | 512 | 8/4 | ~87M |
| base | 12 | 768 | 12/4 | ~180M |
| large | 24 | 1024 | 16/4 | ~600M |
| xl | 24 | 2048 | 16/2 | ~2.1B |
| 3b | 32 | 2560 | 32/8 | ~3.2B |
Hız Karşılaştırması
| Model | 256 token üretim | Yöntem |
|---|---|---|
| GPT-2 | 256 forward pass | Otoregresif |
| omgformer (steps=10) | 10 forward pass | Paralel diffusion |
| omgformer (steps=6, SC) | 6 forward pass | + Self-conditioning |
Self-conditioning ile aynı kalite için daha az adım yeterli olur.
Testler
pytest test_omgformer.py -v
Referanslar
- MDLM: Masked Diffusion Language Models
- DiT: Scalable Diffusion Transformers
- Self-Conditioning in Diffusion Models
- GQA: Training Generalized Multi-Query Attention
- Mercury: Ultra-Fast Language Models (Inception Labs)
Lisans
Apache-2.0
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file omgformer-2.0.0.tar.gz.
File metadata
- Download URL: omgformer-2.0.0.tar.gz
- Upload date:
- Size: 26.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
b6b72f12271d78182e0e308f86cb40363d1e0276fe119066fb5752f204310b11
|
|
| MD5 |
4b98f325679c47a2694d692e2ffbd90b
|
|
| BLAKE2b-256 |
226cf869b80591af9df796c109c647f6917cd033e6efd380bacdd29613bb771f
|
File details
Details for the file omgformer-2.0.0-py3-none-any.whl.
File metadata
- Download URL: omgformer-2.0.0-py3-none-any.whl
- Upload date:
- Size: 26.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
29c4a6f2d96aa13d3fcc7a12f0c7efe9fd0a6deaa166013a90be804b6adea7c4
|
|
| MD5 |
494966312bd5f21145793870072438d5
|
|
| BLAKE2b-256 |
37f8bea0848a479292bf12fe4c866f97846f8c16d3bfead852fe357802c12686
|