A PyTorch implementation of transformer-based language models, including the GPT architecture, for pretraining and fine-tuning

Language Modeling using Transformers (LMT)

An educational PyTorch library for understanding how modern transformer architectures work, from attention mechanisms to full language models. Every component is written to be read and understood: clear code, detailed docstrings, and mathematical notation that maps directly to the source papers.

Features

Attention Mechanisms

| Component | Description | Paper |
| --- | --- | --- |
| Multi-Head Attention | Standard scaled dot-product attention, run in parallel heads | Vaswani et al., 2017 |
| Grouped Query Attention | Query heads share a smaller set of KV heads, shrinking the KV cache | Ainslie et al., 2023 |
| Sliding Window Attention | Local attention over a fixed-size window | Beltagy et al., 2020 |
| Multi-Head Latent Attention | KV compression + decoupled RoPE | DeepSeek-AI, 2024 |
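The grouped-query idea can be checked with a little arithmetic. This plain-Python sketch (hypothetical helper names, not LMT's API) maps each query head to the KV head it reads, assuming the common contiguous grouping in which every num_heads // num_kv_heads consecutive query heads share one KV head, and compares KV-cache sizes for the Quick Start settings. It also assumes head_dim = embed_dim // num_heads and 2-byte (fp16/bf16) cache entries.

```python
def kv_head_for_query(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head read by `query_head`, assuming contiguous grouping."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per KV head
    return query_head // group_size


def kv_cache_elements(num_layers: int, num_kv_heads: int, head_dim: int,
                      seq_len: int, batch: int = 1) -> int:
    """Cached values: keys + values (x2) per layer, per KV head, per position."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch


if __name__ == "__main__":
    num_heads, num_kv_heads, embed_dim = 8, 4, 512  # Quick Start settings
    head_dim = embed_dim // num_heads                # 64 (assumed split)
    print([kv_head_for_query(h, num_heads, num_kv_heads) for h in range(num_heads)])
    # [0, 0, 1, 1, 2, 2, 3, 3]
    bytes_per_value = 2  # assuming fp16/bf16 storage
    mha = kv_cache_elements(6, num_heads, head_dim, 1024) * bytes_per_value
    gqa = kv_cache_elements(6, num_kv_heads, head_dim, 1024) * bytes_per_value
    print(f"MHA {mha / 2**20} MiB, GQA {gqa / 2**20} MiB")
    # MHA 12.0 MiB, GQA 6.0 MiB
```

Setting num_kv_heads equal to num_heads recovers standard multi-head attention, and num_kv_heads = 1 gives multi-query attention; GQA sits between the two.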

Feed-Forward Networks

| Component | Description |
| --- | --- |
| SwiGLU | Gated FFN with Swish activation (LLaMA, Mixtral) |
| Mixture of Experts | Top-k sparse routing with load-balancing loss |
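Written out without a framework, both rows reduce to a few lines. This sketch uses plain Python lists, omits biases, and uses hypothetical names; it illustrates the math rather than LMT's PyTorch modules. SwiGLU multiplies a SiLU-gated projection by a second linear projection before the down-projection. The router keeps the k highest-probability experts and renormalises their weights to sum to 1, which gives the same weights as a softmax over only the top-k logits. That LMT's router follows this exact formulation is an assumption.

```python
import math


def silu(z: float) -> float:
    """Swish/SiLU activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    """Matrix (list of rows) times vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def swiglu(x, w1, w3, w2):
    """SwiGLU FFN: W2 @ (silu(W1 @ x) * (W3 @ x)); W1, W3 are [hidden][d], W2 is [d][hidden]."""
    gate = [silu(g) for g in matvec(w1, x)]
    up = matvec(w3, x)
    return matvec(w2, [g * u for g, u in zip(gate, up)])


def softmax(v: list[float]) -> list[float]:
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(a - m) for a in v]
    s = sum(exps)
    return [e / s for e in exps]


def top_k_route(router_logits: list[float], k: int) -> dict[int, float]:
    """Keep the k most probable experts and renormalise their weights to sum to 1."""
    probs = softmax(router_logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in chosen)
    return {i: probs[i] / total for i in chosen}


def moe_forward(x, router_logits, experts, k=2):
    """Weighted sum of the top-k experts' outputs; unselected experts are never called."""
    out = None
    for idx, weight in top_k_route(router_logits, k).items():
        y = experts[idx](x)
        if out is None:
            out = [weight * v for v in y]
        else:
            out = [o + weight * v for o, v in zip(out, y)]
    return out
```

Because only k experts run per token, compute per token stays roughly that of k FFNs even as the total number of experts (and parameters) grows.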

Other Components: RMSNorm, Rotary Position Embedding (RoPE)
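RMSNorm and RoPE are each a few lines of arithmetic. The sketch below is plain Python with hypothetical names; it assumes the interleaved RoPE pairing (dimensions 2i and 2i+1 rotate together) and the usual base of 10000. Some implementations, possibly including LMT's, pair the first half of the vector with the second half instead, which changes the layout but not the relative-position property checked at the end.

```python
import math


def rms_norm(x: list[float], gain: list[float], eps: float = 1e-6) -> list[float]:
    """RMSNorm: divide by the root mean square (no mean subtraction, no bias), then scale per channel."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]


def rope(x: list[float], position: int, base: float = 10000.0) -> list[float]:
    """Rotate each pair (x[2i], x[2i+1]) by the angle position * base**(-2i/d)."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even dimension")
    out: list[float] = []
    for i in range(0, d, 2):  # i == 2 * pair_index
        theta = position * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


if __name__ == "__main__":
    print([round(v, 4) for v in rms_norm([3.0, 4.0], [1.0, 1.0])])
    # [0.8485, 1.1314]
    q, k = [0.1, 0.2, 0.3, 0.4], [0.5, -0.1, 0.2, 0.7]
    # The attention score depends only on the offset (5 - 3 == 12 - 10).
    print(abs(dot(rope(q, 5), rope(k, 3)) - dot(rope(q, 12), rope(k, 10))) < 1e-9)
    # True
```

Because each pair is rotated rather than shifted, RoPE leaves vector norms unchanged and injects position only through the relative angle between a query and a key.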

Model Architectures

| Model | Key Components |
| --- | --- |
| GPT | Multi-head attention + GELU FFN + learned position embeddings |
| LLaMA | RMSNorm + RoPE + SwiGLU + GQA |
| Mixtral | LLaMA + MoE FFN + sliding window attention |
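The sliding window in the Mixtral row restricts each causal attention row to recent positions. This sketch builds the boolean mask in plain Python; sliding_window_causal_mask is a hypothetical helper, and it assumes the window counts the current token, so with window=3 position i sees i-2, i-1 and i. Off-by-one conventions differ between implementations, so check how LMT interprets window_size before relying on this. In an attention layer, disallowed positions are typically set to -inf before the softmax.

```python
def sliding_window_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True when query i may attend to key j:
    j is not in the future (j <= i) and is among the last `window` positions (i - j < window)."""
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


if __name__ == "__main__":
    for row in sliding_window_causal_mask(5, 3):
        print("".join("x" if allowed else "." for allowed in row))
    # x....
    # xx...
    # xxx..
    # .xxx.
    # ..xxx
```

When window is at least seq_len the mask reduces to the ordinary causal mask. Stacking layers widens the effective reach: in principle, information can flow back roughly (number of layers) x window positions, even though each layer only looks window tokens back.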

Installation

pip install pylmt

Or install from source for development:

git clone https://github.com/michaelellis003/LMT.git
cd LMT
pip install uv
uv sync

Quick Start

import torch
from lmt.models.config import ModelConfig
from lmt.models.llama import LLaMA

config = ModelConfig(
    vocab_size=32000,
    embed_dim=512,
    num_heads=8,
    num_kv_heads=4,     # GQA: 4 KV heads shared across 8 query heads
    num_layers=6,
    context_length=1024,
    dropout=0.0,
)

model = LLaMA(config)
x = torch.randint(0, config.vocab_size, (1, 128))
logits = model(x)  # [1, 128, 32000]
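The Quick Start runs a single forward pass. Turning logits into new tokens needs a decoding loop; below is a minimal greedy sketch in plain Python. greedy_generate and next_token_logits are hypothetical names, not part of LMT. With the model above, next_token_logits could wrap something like model(torch.tensor([tokens]))[0, -1].tolist(), since the logits have shape [batch, seq_len, vocab_size]; in practice you would also call model.eval() and run under torch.no_grad().

```python
from collections.abc import Callable, Sequence


def greedy_generate(
    next_token_logits: Callable[[list[int]], Sequence[float]],
    prompt: list[int],
    max_new_tokens: int,
    context_length: int,
) -> list[int]:
    """Greedy decoding: repeatedly append the highest-scoring next token.

    next_token_logits (hypothetical) returns one score per vocabulary entry
    for the position after the last token it receives.
    """
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        window = tokens[-context_length:]  # never feed more than the model's context
        logits = next_token_logits(window)
        next_id = max(range(len(logits)), key=logits.__getitem__)  # arg-max
        tokens.append(next_id)
    return tokens


if __name__ == "__main__":
    # Toy stand-in for a model: always predicts (last token + 1) mod 5.
    def toy_model(tokens: list[int]) -> list[float]:
        target = (tokens[-1] + 1) % 5
        return [1.0 if i == target else 0.0 for i in range(5)]

    print(greedy_generate(toy_model, [3], max_new_tokens=4, context_length=8))
    # [3, 4, 0, 1, 2]
```

Cropping to the last context_length tokens mirrors the context_length limit in ModelConfig; sampling strategies such as temperature or top-k sampling would replace the arg-max line.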

Using Individual Layers

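import torch

from lmt.layers.attention import GroupedQueryAttention
from lmt.layers.ffn import SwiGLU
from lmt.layers.normalization import RMSNorm

# config: the ModelConfig from Quick Start (embed_dim=512)
attn_norm = RMSNorm(d_model=512)
ffn_norm = RMSNorm(d_model=512)  # separate norm so each sublayer has its own gain
attn = GroupedQueryAttention(config)
ffn = SwiGLU(d_model=512)

x = torch.randn(1, 64, 512)  # [batch, seq_len, embed_dim]
x = x + attn(attn_norm(x))   # Pre-norm attention with residual
x = x + ffn(ffn_norm(x))     # Pre-norm FFN with residual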

Mixture of Experts

import torch

from lmt.models.config import ModelConfig
from lmt.models.mixtral import Mixtral

config = ModelConfig(
    vocab_size=32000, embed_dim=512, num_heads=8,
    num_kv_heads=4, num_layers=8,
    context_length=2048, window_size=256, dropout=0.0,
)

model = Mixtral(config, num_experts=8, top_k=2)
x = torch.randint(0, config.vocab_size, (1, 128))  # token IDs, not embeddings
logits = model(x)
aux_loss = model.aux_loss  # load-balancing loss for training
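The README does not spell out how aux_loss is computed. One widely used formulation, from the Switch Transformer, is num_experts * sum_i f_i * P_i, where f_i is the share of routing assignments sent to expert i and P_i is the mean router probability for expert i. The plain-Python sketch below implements that formulation for top-k routing (each of the k assignments per token counts toward f_i) as an illustration; treat it as an assumption about LMT, not its actual code.

```python
import math


def softmax(v: list[float]) -> list[float]:
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    s = sum(exps)
    return [e / s for e in exps]


def load_balancing_loss(router_logits: list[list[float]], k: int) -> float:
    """Switch-Transformer-style auxiliary loss: num_experts * sum_i f_i * P_i.

    router_logits: one list of per-expert logits per token.
    f_i: fraction of the (token, slot) top-k assignments that went to expert i.
    P_i: mean router probability assigned to expert i.
    Equals 1.0 for uniform routing and grows as tokens pile onto a few experts.
    """
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    counts = [0] * num_experts
    prob_sums = [0.0] * num_experts
    for logits in router_logits:
        probs = softmax(logits)
        for i, p in enumerate(probs):
            prob_sums[i] += p
        for i in sorted(range(num_experts), key=lambda j: probs[j], reverse=True)[:k]:
            counts[i] += 1
    f = [c / (num_tokens * k) for c in counts]
    P = [s / num_tokens for s in prob_sums]
    return num_experts * sum(fi * pi for fi, pi in zip(f, P))


if __name__ == "__main__":
    balanced = [[10.0 if j == t else 0.0 for j in range(4)] for t in range(4)]
    collapsed = [[10.0, 0.0, 0.0, 0.0] for _ in range(4)]
    print(round(load_balancing_loss(balanced, k=1), 3),
          round(load_balancing_loss(collapsed, k=1), 3))
    # 1.0 3.999
```

During training the auxiliary term is usually added to the language-modelling loss with a small weight, for example loss = lm_loss + 0.01 * aux_loss; 0.01 is the coefficient reported in the Switch Transformer paper, not a documented LMT default.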

Project Structure

src/lmt/
  layers/
    attention/     # MHA, GQA, Sliding Window, MLA
    ffn/           # SwiGLU, MoE (Router + Experts)
    normalization/ # RMSNorm
    positional/    # RoPE
  models/
    gpt/           # GPT (original decoder-only transformer)
    llama/         # LLaMA (RMSNorm + RoPE + SwiGLU + GQA)
    mixtral/       # Mixtral (LLaMA + MoE + sliding window)
  training/        # Trainer, configs, dataloaders
  tokenizer/       # BPE, naive tokenizers

Development

uv run pytest tests/ -v       # Run tests (171 passing)
uv run ruff check src/ tests/ # Lint
uv run ruff format src/ tests/ # Format
uv run pyright src/            # Type check
uv run mkdocs serve            # Local docs server

License

Apache License 2.0. See LICENSE for details.

Download files

Download the file for your platform.

Source Distribution

pylmt-0.6.0.tar.gz (646.8 kB)

Uploaded Source

Built Distribution

pylmt-0.6.0-py3-none-any.whl (79.8 kB)

Uploaded Python 3

File details

Details for the file pylmt-0.6.0.tar.gz.

File metadata

  • Download URL: pylmt-0.6.0.tar.gz
  • Size: 646.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pylmt-0.6.0.tar.gz
Algorithm Hash digest
SHA256 0405d0cceb409c88768f420d59e0a38a0de50b66deebc4c95c304c8ba957d213
MD5 86d0cd9a1febbf738828d6ec8b2cbbfc
BLAKE2b-256 28ef4d9b912bd3b7360a588adcbb81451b95246a7fb98406ff08c262547ccf19

File details

Details for the file pylmt-0.6.0-py3-none-any.whl.

File metadata

  • Download URL: pylmt-0.6.0-py3-none-any.whl
  • Size: 79.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pylmt-0.6.0-py3-none-any.whl
Algorithm Hash digest
SHA256 273b33ebff15d2be5f34597a7e96120b8dbed8aa09719b2acbbd1ebbc07bebea
MD5 edae8f95b238fe61a4aa03eb7bf6a303
BLAKE2b-256 04f15fe26402e667ed81606cef67575581efe7db6a3e26477c9465803a4655fc
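To check a downloaded file against the SHA256 digests listed above, stream it through Python's standard hashlib module. The helper names below are illustrative; on Python 3.11+ hashlib.file_digest can do the chunked reading for you.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 of a file, read in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: Path, expected_hex: str) -> bool:
    """Compare a file against a published digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


if __name__ == "__main__":
    wheel = Path("pylmt-0.6.0-py3-none-any.whl")
    published = "273b33ebff15d2be5f34597a7e96120b8dbed8aa09719b2acbbd1ebbc07bebea"
    if wheel.exists():
        print("OK" if matches(wheel, published) else "MISMATCH: do not install")
    else:
        print(f"{wheel} not found in the current directory")
```

pip can enforce the same check at install time: put a line such as pylmt==0.6.0 --hash=sha256:<digest from the table above> in a requirements file and run pip install --require-hashes -r requirements.txt.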
