Skip to main content

Paralel Diffusion Dil Modeli — GQA + AdaLN-Zero + Self-Conditioning

Project description

omgformer v2

Paralel Diffusion Dil Modeli — aynı anda tüm tokenleri üretir.

Klasik GPT 256 token için 256 forward pass yapar.
omgformer aynı çıktıyı 8–10 forward pass ile üretir.


Temel Fikir

Adım 0: "Merhaba [MASK] [MASK] [MASK] [MASK] [MASK]"
Adım 1: "Merhaba dünya  [MASK] [MASK] [MASK] [MASK]"
Adım 2: "Merhaba dünya  nasıl  gidiyor [MASK] [MASK]"
Adım 3: "Merhaba dünya  nasıl  gidiyor  bugün   ?"

Her adımda tüm sequence bir forward pass'ten geçer.
Model hem sola hem sağa bakar.
En yüksek güvenli tokenlar açılır.
Python for loop yok — tamamen vektörize.


v2 Yenilikleri

GQA (Grouped Query Attention)

KV head sayısını azalt, Q head sayısını koru.
Büyük modellerde KV cache belleğini 4–8× küçültür.

# 16 Q head, 4 KV head → 4× bellek tasarrufu
cfg = OMGConfig(num_heads=16, num_kv_heads=4)

AdaLN-Zero

Timestep koşullandırması her katmana scale/shift/gate üçlüsüyle enjekte edilir.
Input'a toplamak yerine — DiT mimarisinden ilham.

x_attn = norm(x) * (1 + scale) + shift
x = x + gate.tanh() * attention(x_attn)

Gate'ler sıfır başlatılır → eğitim başında blok = tam geçirgen.

Self-Conditioning

Önceki adımın soft tahminleri bir sonraki adıma beslenir.

Adım N:  logits_n  = model(noisy, self_cond=None)
Adım N+1: logits_n1 = model(noisy, self_cond=soft_embed(logits_n))

Eğitimde %50 ihtimalle aktif edilir.
Aynı step sayısında dramatik kalite artışı.

Absorbing Diffusion

Sadece MASK değil, %10 random token da eklenir (BERT tarzı).
Daha zengin training signal → daha hızlı öğrenme.

%80 → MASK token
%10 → rastgele başka token  ← YENİ
%10 → orijinal token kalır

AMP + Gradient Accumulation

TrainingConfig(
    use_amp=True,
    amp_dtype="bfloat16",   # 1.5-2× hız, yarı bellek
    grad_accum_steps=8,     # 8× büyük efektif batch
)

torch.compile

TrainingConfig(use_compile=True)  # ~30% ek hız

Remasking

Inference'ta düşük güvenli tokenlar yeniden maskelenir.
Ardışık adımlarda daha iyi global tutarlılık.

decoder.generate(prompt, remask_prob=0.05)

Kurulum

pip install -e ".[train,dev]"

Hızlı Başlangıç

from omgformer import OMGConfig, OMGModel, MaskScheduler, ParallelDecoder

cfg = OMGConfig.from_preset("omgformer-base")
model = OMGModel(cfg).eval()

sched = MaskScheduler(
    steps=64,
    mask_token_id=cfg.mask_token_id,
    vocab_size=cfg.vocab_size,
)
decoder = ParallelDecoder(model, sched)

prompt_ids = tokenizer.encode("İstanbul'un tarihi")
out = decoder.generate(
    prompt_ids,
    new_tokens=128,
    steps=10,
    temperature=0.9,
    top_p=0.95,
    remask_prob=0.05,
)
print(tokenizer.decode(out[0]))

Pipeline

from omgformer import pipeline

gen = pipeline("text-generation", model="omgformer-base")
result = gen("Yapay zeka", new_tokens=256, steps=10)
print(result["generated_text"])

Eğitim

from omgformer.training import Trainer, TrainingConfig

train_cfg = TrainingConfig(
    steps=100_000,
    batch_size=32,
    lr=3e-4,
    use_amp=True,
    use_compile=True,
    grad_accum_steps=4,
    self_cond_prob=0.5,
    use_wandb=True,
)

trainer = Trainer(model, sched, train_cfg, get_batch=my_data_loader)
trainer.fit()

Çok GPU (FSDP)

from omgformer.training import wrap_fsdp

model = wrap_fsdp(model, device_id=local_rank)
# torchrun --nproc_per_node=8 train.py

Model Boyutları

Preset Katman Hidden Heads (Q/KV) ~Parametre
tiny 4 256 4/4 ~14M
small 8 512 8/4 ~87M
base 12 768 12/4 ~180M
large 24 1024 16/4 ~600M
xl 24 2048 16/2 ~2.1B
3b 32 2560 32/8 ~3.2B

Hız Karşılaştırması

Model 256 token üretim Yöntem
GPT-2 256 forward pass Otoregresif
omgformer (steps=10) 10 forward pass Paralel diffusion
omgformer (steps=6, SC) 6 forward pass + Self-conditioning

Self-conditioning ile aynı kalite için daha az adım yeterli olur.


Testler

pytest test_omgformer.py -v

Referanslar


Lisans

Apache-2.0

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

omgformer-2.0.0.tar.gz (26.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

omgformer-2.0.0-py3-none-any.whl (26.5 kB view details)

Uploaded Python 3

File details

Details for the file omgformer-2.0.0.tar.gz.

File metadata

  • Download URL: omgformer-2.0.0.tar.gz
  • Upload date:
  • Size: 26.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.7

File hashes

Hashes for omgformer-2.0.0.tar.gz
Algorithm Hash digest
SHA256 b6b72f12271d78182e0e308f86cb40363d1e0276fe119066fb5752f204310b11
MD5 4b98f325679c47a2694d692e2ffbd90b
BLAKE2b-256 226cf869b80591af9df796c109c647f6917cd033e6efd380bacdd29613bb771f

See more details on using hashes here.

File details

Details for the file omgformer-2.0.0-py3-none-any.whl.

File metadata

  • Download URL: omgformer-2.0.0-py3-none-any.whl
  • Upload date:
  • Size: 26.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.7

File hashes

Hashes for omgformer-2.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 29c4a6f2d96aa13d3fcc7a12f0c7efe9fd0a6deaa166013a90be804b6adea7c4
MD5 494966312bd5f21145793870072438d5
BLAKE2b-256 37f8bea0848a479292bf12fe4c866f97846f8c16d3bfead852fe357802c12686

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page