
A PyTorch implementation of transformer-based language models including GPT architecture for pretraining and fine-tuning


Language Modeling using Transformers (LMT)


An educational PyTorch library for understanding how modern transformer architectures work, from attention mechanisms to full language models. Every component is written to be read: clear code, detailed docstrings, and mathematical notation that maps directly to the papers.

Features

Attention Mechanisms

| Component | Description | Paper |
|---|---|---|
| Multi-Head Attention | Standard scaled dot-product attention | Vaswani et al., 2017 |
| Grouped Query Attention | Shared KV heads for efficiency | Ainslie et al., 2023 |
| Sliding Window Attention | Local attention with a fixed window | Beltagy et al., 2020 |
| Multi-Head Latent Attention | KV compression + decoupled RoPE | DeepSeek-AI, 2024 |
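
To build intuition for GQA (this is a standalone sketch, not the library's `GroupedQueryAttention` API): each group of query heads shares a single KV head, so K and V are repeated across the group before ordinary scaled dot-product attention.

```python
import torch

def gqa_sketch(q, k, v):
    """Scaled dot-product attention with shared KV heads (GQA sketch).

    q: [num_heads, seq, head_dim]; k, v: [num_kv_heads, seq, head_dim].
    """
    group = q.shape[0] // k.shape[0]       # query heads per KV head
    k = k.repeat_interleave(group, dim=0)  # expand KV heads to match q
    v = v.repeat_interleave(group, dim=0)
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

q = torch.randn(8, 16, 64)  # 8 query heads
k = torch.randn(4, 16, 64)  # 4 shared KV heads
v = torch.randn(4, 16, 64)
out = gqa_sketch(q, k, v)   # [8, 16, 64]
```

With 4 KV heads serving 8 query heads, the KV cache is half the size of standard multi-head attention while keeping the full set of query heads.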

Feed-Forward Networks

| Component | Description |
|---|---|
| SwiGLU | Gated FFN with Swish activation (LLaMA, Mixtral) |
| Mixture of Experts | Top-k sparse routing with load-balancing loss |
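
For intuition, a minimal SwiGLU sketch (a hypothetical module, not the library's `SwiGLU` signature): the output is `down(SiLU(gate(x)) * up(x))`, i.e. a Swish-gated up-projection followed by a down-projection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUSketch(nn.Module):
    """Minimal SwiGLU FFN: down(silu(gate(x)) * up(x))."""

    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(d_model, hidden, bias=False)
        self.up = nn.Linear(d_model, hidden, bias=False)
        self.down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x):
        # SiLU (a.k.a. Swish) gates the up-projection elementwise
        return self.down(F.silu(self.gate(x)) * self.up(x))

ffn = SwiGLUSketch(d_model=512, hidden=1376)
y = ffn(torch.randn(2, 16, 512))  # [2, 16, 512]
```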

Other Components: RMSNorm, Rotary Position Embedding (RoPE)
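
RMSNorm drops LayerNorm's mean-centering and bias, rescaling only by the root mean square. A functional sketch (the library's `RMSNorm` is a module):

```python
import torch

def rms_norm(x, weight, eps: float = 1e-6):
    # y = x / sqrt(mean(x^2) + eps) * g  -- no mean subtraction, no bias
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

x = torch.randn(1, 8, 512)
y = rms_norm(x, torch.ones(512))  # same shape, per-token RMS ~= 1
```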

Model Architectures

| Model | Key Components |
|---|---|
| GPT | Multi-head attention + GELU FFN + learned position embeddings |
| LLaMA | RMSNorm + RoPE + SwiGLU + GQA |
| Mixtral | LLaMA + MoE FFN + sliding window attention |
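
The sliding-window attention used by Mixtral limits each position to the most recent `window` tokens. A sketch of the boolean mask (a hypothetical helper, not the library API):

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int):
    # True where position i may attend to position j:
    # causal (j <= i) and within the window (j > i - window)
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)

mask = sliding_window_causal_mask(seq_len=6, window=3)
# each row has at most `window` True entries
```

Because each layer sees `window` tokens and the receptive field compounds across layers, attention cost per layer is O(seq_len * window) instead of O(seq_len^2).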

Installation

```shell
pip install pylmt
```

Or install from source for development:

```shell
git clone https://github.com/michaelellis003/LMT.git
cd LMT
pip install uv
uv sync
```

Quick Start

```python
import torch
from lmt.models.config import ModelConfig
from lmt.models.llama import LLaMA

config = ModelConfig(
    vocab_size=32000,
    embed_dim=512,
    num_heads=8,
    num_kv_heads=4,     # GQA: 4 KV heads shared across 8 query heads
    num_layers=6,
    context_length=1024,
    dropout=0.0,
)

model = LLaMA(config)
x = torch.randint(0, config.vocab_size, (1, 128))
logits = model(x)  # [1, 128, 32000]
```

Using Individual Layers

```python
import torch
from lmt.layers.attention import GroupedQueryAttention
from lmt.layers.ffn import SwiGLU
from lmt.layers.normalization import RMSNorm

norm = RMSNorm(d_model=512)
attn = GroupedQueryAttention(config)  # `config` from the Quick Start above
ffn = SwiGLU(d_model=512)

x = torch.randn(1, 64, 512)
x = x + attn(norm(x))  # Pre-norm attention
x = x + ffn(norm(x))   # Pre-norm FFN
```

Mixture of Experts

```python
import torch
from lmt.models.config import ModelConfig
from lmt.models.mixtral import Mixtral

config = ModelConfig(
    vocab_size=32000, embed_dim=512, num_heads=8,
    num_kv_heads=4, num_layers=8,
    context_length=2048, window_size=256, dropout=0.0,
)

model = Mixtral(config, num_experts=8, top_k=2)
x = torch.randint(0, config.vocab_size, (1, 128))
logits = model(x)
aux_loss = model.aux_loss  # load-balancing loss; add it to the LM loss when training
```
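
For intuition, here is a sketch of what top-k routing with a load-balancing auxiliary loss looks like (Switch-Transformer-style; the library's Router internals may differ):

```python
import torch
import torch.nn.functional as F

def route_top_k(router_logits, top_k: int = 2):
    """router_logits: [num_tokens, num_experts] -> (weights, indices, aux_loss)."""
    num_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits, dim=-1)
    topk_probs, topk_idx = probs.topk(top_k, dim=-1)
    # renormalize so each token's chosen-expert weights sum to 1
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
    # load-balancing loss: fraction of tokens whose top-1 choice is each expert,
    # times the mean routing probability; minimized when usage is uniform
    frac_tokens = F.one_hot(topk_idx[:, 0], num_experts).float().mean(dim=0)
    mean_probs = probs.mean(dim=0)
    aux_loss = num_experts * (frac_tokens * mean_probs).sum()
    return topk_probs, topk_idx, aux_loss

logits = torch.randn(32, 8)  # 32 tokens, 8 experts
w, idx, aux = route_top_k(logits, top_k=2)
```

Each token's output is then the weighted sum of its `top_k` experts' FFN outputs, and `aux` is added to the language-modeling loss to keep expert usage balanced.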

Project Structure

```text
src/lmt/
  layers/
    attention/     # MHA, GQA, Sliding Window, MLA
    ffn/           # SwiGLU, MoE (Router + Experts)
    normalization/ # RMSNorm
    positional/    # RoPE
  models/
    gpt/           # GPT (original decoder-only transformer)
    llama/         # LLaMA (RMSNorm + RoPE + SwiGLU + GQA)
    mixtral/       # Mixtral (LLaMA + MoE + sliding window)
  training/        # Trainer, configs, dataloaders
  tokenizer/       # BPE, naive tokenizers
```

Development

```shell
uv run pytest tests/ -v        # Run tests (171 passing)
uv run ruff check src/ tests/  # Lint
uv run ruff format src/ tests/ # Format
uv run pyright src/            # Type check
uv run mkdocs serve            # Local docs server
```

License

Apache License 2.0. See LICENSE for details.
