A PyTorch implementation of transformer-based language models including GPT architecture for pretraining and fine-tuning


Language Modeling using Transformers (LMT)


An educational PyTorch library for understanding how modern transformer architectures work -- from attention mechanisms to full language models. Every component is written to be understood, with clear code, detailed docstrings, and mathematical notation that maps directly to the papers.

Features

Attention Mechanisms

| Component | Description | Paper |
| --- | --- | --- |
| Multi-Head Attention | Standard scaled dot-product attention | Vaswani et al., 2017 |
| Grouped Query Attention | Shared KV heads for efficiency | Ainslie et al., 2023 |
| Sliding Window Attention | Local attention with a fixed window | Beltagy et al., 2020 |
| Multi-Head Latent Attention | KV compression + decoupled RoPE | DeepSeek-AI, 2024 |
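The KV-sharing idea behind grouped-query attention can be sketched in a few lines of PyTorch. This is an illustrative toy, not the library's `GroupedQueryAttention` implementation; it assumes the inputs are already split into heads, and omits masking, RoPE, and projections:

```python
import torch

def grouped_query_attention(q, k, v, num_heads, num_kv_heads):
    """Toy GQA: each KV head serves num_heads // num_kv_heads query heads.
    q: [batch, num_heads, seq, head_dim]; k, v: [batch, num_kv_heads, seq, head_dim]."""
    group = num_heads // num_kv_heads
    # Repeat each KV head so every query head has a matching K/V
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return attn @ v

q = torch.randn(1, 8, 16, 64)   # 8 query heads
k = torch.randn(1, 4, 16, 64)   # only 4 KV heads to store in the cache
v = torch.randn(1, 4, 16, 64)
out = grouped_query_attention(q, k, v, num_heads=8, num_kv_heads=4)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```

The efficiency win is in the KV cache: only `num_kv_heads` keys and values are stored per token, while compute still runs over all query heads.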

Feed-Forward Networks

| Component | Description |
| --- | --- |
| SwiGLU | Gated FFN with Swish activation (LLaMA, Mixtral) |
| Mixture of Experts | Top-k sparse routing with a load-balancing loss |
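For intuition, the SwiGLU computation is a SiLU-gated up-projection followed by a down-projection. The sketch below uses raw weight matrices to keep the math visible; it is not the library's `SwiGLU` module:

```python
import torch
import torch.nn.functional as F

def swiglu(x, w_gate, w_up, w_down):
    # SwiGLU: gate the up-projection with SiLU (Swish), then project back down
    return (F.silu(x @ w_gate) * (x @ w_up)) @ w_down

d_model, hidden = 512, 1376  # LLaMA-style hidden dim, roughly (8/3) * d_model
x = torch.randn(1, 16, d_model)
w_gate = torch.randn(d_model, hidden)
w_up = torch.randn(d_model, hidden)
w_down = torch.randn(hidden, d_model)
y = swiglu(x, w_gate, w_up, w_down)
print(y.shape)  # torch.Size([1, 16, 512])
```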

Other Components: RMSNorm, Rotary Position Embedding (RoPE)
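RMSNorm is small enough to show in full: it rescales each feature vector by its root mean square, with no mean subtraction or bias. A minimal sketch (not the library's `RMSNorm` module):

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: divide by the root-mean-square over the feature dim, then scale
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

x = torch.randn(2, 5, 512)
y = rms_norm(x, torch.ones(512))
print(y.shape)  # torch.Size([2, 5, 512])
```

Dropping the mean-centering and bias of LayerNorm makes RMSNorm cheaper while working comparably well in practice, which is why LLaMA-family models use it.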

Model Architectures

| Model | Key Components |
| --- | --- |
| GPT | Multi-head attention + GELU FFN + learned position embeddings |
| LLaMA | RMSNorm + RoPE + SwiGLU + GQA |
| Mixtral | LLaMA + MoE FFN + sliding window attention |

Installation

```bash
pip install pylmt
```

Or install from source for development:

```bash
git clone https://github.com/michaelellis003/LMT.git
cd LMT
pip install uv
uv sync
```

Quick Start

```python
import torch
from lmt.models.config import ModelConfig
from lmt.models.llama import LLaMA

config = ModelConfig(
    vocab_size=32000,
    embed_dim=512,
    num_heads=8,
    num_kv_heads=4,     # GQA: 4 KV heads shared across 8 query heads
    num_layers=6,
    context_length=1024,
    dropout=0.0,
)

model = LLaMA(config)
x = torch.randint(0, config.vocab_size, (1, 128))
logits = model(x)  # [1, 128, 32000]
```
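Since the model returns next-token logits, a greedy decoding loop can be built on top of it. The `generate` helper below is a hypothetical sketch, not a function the library provides; it is shown with a stand-in model so it runs standalone, and assumes the model maps `[batch, seq]` token IDs to `[batch, seq, vocab]` logits:

```python
import torch

@torch.no_grad()
def generate(model, tokens, max_new_tokens, context_length):
    # Append the argmax token each step, cropping the input to the context window
    for _ in range(max_new_tokens):
        logits = model(tokens[:, -context_length:])
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens

# Demo with a stand-in "model" that returns random logits
vocab_size = 32000
dummy = lambda ids: torch.randn(ids.shape[0], ids.shape[1], vocab_size)
out = generate(dummy, torch.zeros(1, 4, dtype=torch.long),
               max_new_tokens=8, context_length=1024)
print(out.shape)  # torch.Size([1, 12])
```

Swapping the argmax for temperature or top-k sampling changes one line; the loop structure stays the same.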

Using Individual Layers

```python
import torch

from lmt.layers.attention import GroupedQueryAttention
from lmt.layers.ffn import SwiGLU
from lmt.layers.normalization import RMSNorm

# `config` is the ModelConfig from the Quick Start example above
norm = RMSNorm(d_model=512)
attn = GroupedQueryAttention(config)
ffn = SwiGLU(d_model=512)

x = torch.randn(1, 64, 512)
x = x + attn(norm(x))  # Pre-norm attention
x = x + ffn(norm(x))   # Pre-norm FFN
```

Mixture of Experts

```python
from lmt.models.mixtral import Mixtral

# ModelConfig and the input batch x are as in the examples above
config = ModelConfig(
    vocab_size=32000, embed_dim=512, num_heads=8,
    num_kv_heads=4, num_layers=8,
    context_length=2048, window_size=256, dropout=0.0,
)

model = Mixtral(config, num_experts=8, top_k=2)
logits = model(x)
aux_loss = model.aux_loss  # load-balancing loss, added to the training objective
```
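For intuition about what the router and that auxiliary loss do, here is a toy sketch of top-k routing with a Switch-style load-balancing term. This is illustrative only, not the library's router; `top_k_route` is a hypothetical name:

```python
import torch
import torch.nn.functional as F

def top_k_route(x, router_weight, top_k=2):
    """Toy top-k routing. x: [tokens, d_model]; router_weight: [d_model, num_experts]."""
    probs = F.softmax(x @ router_weight, dim=-1)
    gates, expert_idx = probs.topk(top_k, dim=-1)
    gates = gates / gates.sum(dim=-1, keepdim=True)  # renormalize over chosen experts
    return gates, expert_idx, probs

x = torch.randn(16, 512)
w = torch.randn(512, 8)  # 8 experts
gates, idx, probs = top_k_route(x, w, top_k=2)

# Load-balancing loss: fraction of tokens routed to each expert times its
# mean router probability, summed; minimized when routing is uniform
num_experts = w.shape[-1]
f = torch.zeros(num_experts).scatter_add_(
    0, idx.flatten(), torch.ones(idx.numel())) / idx.numel()
aux_loss = num_experts * (f * probs.mean(dim=0)).sum()
print(gates.shape, idx.shape)  # torch.Size([16, 2]) torch.Size([16, 2])
```

Each token's output is then the gate-weighted sum of its chosen experts' FFN outputs, so only `top_k` of the experts run per token.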

Project Structure

```
src/lmt/
  layers/
    attention/     # MHA, GQA, Sliding Window, MLA
    ffn/           # SwiGLU, MoE (Router + Experts)
    normalization/ # RMSNorm
    positional/    # RoPE
  models/
    gpt/           # GPT (original decoder-only transformer)
    llama/         # LLaMA (RMSNorm + RoPE + SwiGLU + GQA)
    mixtral/       # Mixtral (LLaMA + MoE + sliding window)
  training/        # Trainer, configs, dataloaders
  tokenizer/       # BPE, naive tokenizers
```

Development

```bash
uv run pytest tests/ -v        # Run tests (171 passing)
uv run ruff check src/ tests/  # Lint
uv run ruff format src/ tests/ # Format
uv run pyright src/            # Type check
uv run mkdocs serve            # Local docs server
```

License

Apache License 2.0. See LICENSE for details.
