A PyTorch implementation of transformer-based language models, including the GPT architecture, for pretraining and fine-tuning

Language Modeling using Transformers (LMT)

An educational PyTorch library for understanding how modern transformer architectures work, from attention mechanisms to full language models. Every component is written to be understood, with clear code, detailed docstrings, and mathematical notation that maps directly to the papers.

Features

Attention Mechanisms

Component | Description | Paper
--- | --- | ---
Multi-Head Attention | Standard scaled dot-product attention | Vaswani et al., 2017
Grouped Query Attention | Shared KV heads for efficiency | Ainslie et al., 2023
Sliding Window Attention | Local attention with a fixed window | Beltagy et al., 2020
Multi-Head Latent Attention | KV compression + decoupled RoPE | DeepSeek-AI, 2024
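To make the first three rows concrete, here is a minimal sketch in plain PyTorch. It is not pylmt's implementation: the function name `causal_attention`, the `[batch, heads, seq, head_dim]` tensor layout, and the convention that consecutive query heads share one KV head are assumptions. Multi-head latent attention is left out because its KV compression needs learned projections.

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


def causal_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    window: Optional[int] = None,
) -> torch.Tensor:
    """Causal scaled dot-product attention with optional GQA and sliding window.

    Assumed shapes: q is [batch, num_heads, seq, head_dim]; k and v are
    [batch, num_kv_heads, seq, head_dim], where num_kv_heads divides num_heads.
    num_kv_heads == num_heads is ordinary multi-head attention.
    """
    group = q.size(1) // k.size(1)
    # GQA: each KV head is shared by `group` consecutive query heads.
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)

    # Scores QK^T / sqrt(head_dim), shape [batch, heads, seq, seq].
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))

    seq = q.size(2)
    i = torch.arange(seq, device=q.device).unsqueeze(1)  # query positions
    j = torch.arange(seq, device=q.device).unsqueeze(0)  # key positions
    allowed = j <= i  # causal: never attend to future tokens
    if window is not None:
        # Sliding window: only the `window` most recent positions, including i.
        allowed = allowed & (i - j < window)

    scores = scores.masked_fill(~allowed, float("-inf"))
    return F.softmax(scores, dim=-1) @ v
```

With `num_kv_heads == num_heads` the function reduces to ordinary causal multi-head attention, which PyTorch 2.x also provides as `F.scaled_dot_product_attention(q, k, v, is_causal=True)`; the explicit version is shown only to make the causal and windowed masks visible.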

Feed-Forward Networks

Component | Description
--- | ---
SwiGLU | Gated FFN with Swish activation (used in LLaMA, Mixtral)
Mixture of Experts | Top-k sparse routing with load-balancing loss
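The two feed-forward components can be sketched as follows. This is an illustration under assumptions, not pylmt's code: `SwiGLUSketch` and `top_k_route` are hypothetical names, the hidden width is left to the caller, and the load-balancing term follows one common Switch Transformer style formulation (number of experts times the sum over experts of routed-slot fraction times mean router probability), which may differ in scaling from pylmt's.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUSketch(nn.Module):
    """FFN(x) = W2(SiLU(W1 x) * W3 x): the gated feed-forward used in LLaMA-style models."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)  # gate projection
        self.w3 = nn.Linear(d_model, d_hidden, bias=False)  # value projection
        self.w2 = nn.Linear(d_hidden, d_model, bias=False)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU (Swish with beta = 1) of the gate, multiplied elementwise by the value path.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def top_k_route(logits: torch.Tensor, k: int):
    """Top-k routing over router logits of shape [num_tokens, num_experts].

    Returns (weights [num_tokens, k], expert_ids [num_tokens, k], aux_loss).
    """
    num_experts = logits.size(-1)
    probs = F.softmax(logits, dim=-1)
    weights, expert_ids = probs.topk(k, dim=-1)
    # Renormalize so each token's k expert weights sum to 1.
    weights = weights / weights.sum(dim=-1, keepdim=True)

    # Load-balancing loss: fraction of routed slots per expert times mean router prob.
    one_hot = F.one_hot(expert_ids, num_experts).float()  # [tokens, k, experts]
    fraction = one_hot.sum(dim=(0, 1)) / one_hot.sum()
    mean_prob = probs.mean(dim=0)
    aux_loss = num_experts * (fraction * mean_prob).sum()
    return weights, expert_ids, aux_loss
```

The renormalized top-k weights equal a softmax over just the selected logits. The Switch Transformer paper motivates the auxiliary term as encouraging uniform routing; with uniform router probabilities it evaluates to exactly 1. Dispatching tokens to the expert FFNs and summing their weighted outputs is omitted here.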

Other Components: RMSNorm, Rotary Position Embedding (RoPE)
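For reference, a minimal sketch of both components. Assumptions: the class and function names are hypothetical, `eps=1e-6` and `base=10000.0` are common defaults rather than confirmed pylmt values, and RoPE here uses the "rotate-half" channel pairing (channel i with channel i + head_dim/2) found in several open implementations. The RoFormer paper pairs adjacent channels instead; the two are equivalent up to a fixed permutation of channels.

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """y = x / sqrt(mean(x^2) + eps) * weight (no mean-centering, no bias)."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of x [..., seq, head_dim] by position-dependent angles."""
    seq, dim = x.shape[-2], x.shape[-1]
    half = dim // 2
    # Frequencies theta_i = base^(-2i / head_dim) for i = 0 .. head_dim/2 - 1.
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32, device=x.device) / half)
    positions = torch.arange(seq, dtype=torch.float32, device=x.device)
    angles = positions[:, None] * inv_freq[None, :]  # [seq, half]
    cos, sin = angles.cos(), angles.sin()
    # Rotate-half pairing: channel i is rotated together with channel i + half.
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

Because each channel pair is rotated by an angle proportional to its position, the dot product between a rotated query at position m and a rotated key at position n depends only on n - m, which is how RoPE injects relative position into attention.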

Model Architectures

Model | Key Components
--- | ---
GPT | Multi-head attention + GELU FFN + learned position embeddings
LLaMA | RMSNorm + RoPE + SwiGLU + GQA
Mixtral | LLaMA + MoE FFN + sliding window attention

Installation

pip install pylmt

Or install from source for development:

git clone https://github.com/michaelellis003/LMT.git
cd LMT
pip install uv
uv sync

Quick Start

import torch
from lmt.models.config import ModelConfig
from lmt.models.llama import LLaMA

config = ModelConfig(
    vocab_size=32000,
    embed_dim=512,
    num_heads=8,
    num_kv_heads=4,     # GQA: 4 KV heads shared across 8 query heads
    num_layers=6,
    context_length=1024,
    dropout=0.0,
)

model = LLaMA(config)
x = torch.randint(0, config.vocab_size, (1, 128))
logits = model(x)  # [1, 128, 32000]
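The `num_kv_heads=4` setting mainly pays off during autoregressive decoding, where keys and values for past tokens are typically cached. A back-of-the-envelope estimate for the config above, assuming head_dim = embed_dim / num_heads = 64, fp32 values, batch size 1, and a cache filled to the full `context_length` (this is generic arithmetic; it does not imply that pylmt ships a KV cache):

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch=1, bytes_per_value=4):
    # Factor 2: one cached tensor for keys and one for values in every layer.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_value


head_dim = 512 // 8  # embed_dim / num_heads (assumed split) = 64
gqa = kv_cache_bytes(num_layers=6, num_kv_heads=4, head_dim=head_dim, seq_len=1024)
mha = kv_cache_bytes(num_layers=6, num_kv_heads=8, head_dim=head_dim, seq_len=1024)
print(gqa / 2**20, mha / 2**20)  # → 12.0 24.0  (MiB, fp32, batch 1)
```

Halving the number of KV heads halves the cache, while all 8 query heads keep their own projections.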

Using Individual Layers

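import torch
from lmt.layers.attention import GroupedQueryAttention
from lmt.layers.ffn import SwiGLU
from lmt.layers.normalization import RMSNorm

# Reuses `config` from the Quick Start example (embed_dim=512).
attn_norm = RMSNorm(d_model=512)
ffn_norm = RMSNorm(d_model=512)  # each sublayer gets its own norm weights
attn = GroupedQueryAttention(config)
ffn = SwiGLU(d_model=512)

x = torch.randn(1, 64, 512)   # [batch, seq_len, embed_dim]
x = x + attn(attn_norm(x))    # Pre-norm attention + residual
x = x + ffn(ffn_norm(x))      # Pre-norm FFN + residual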

Mixture of Experts

import torch
from lmt.models.config import ModelConfig
from lmt.models.mixtral import Mixtral

config = ModelConfig(
    vocab_size=32000, embed_dim=512, num_heads=8,
    num_kv_heads=4, num_layers=8,
    context_length=2048, window_size=256, dropout=0.0,
)

model = Mixtral(config, num_experts=8, top_k=2)
x = torch.randint(0, config.vocab_size, (1, 128))  # token IDs, not embeddings
logits = model(x)
aux_loss = model.aux_loss  # load-balancing loss for training
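The example exposes `model.aux_loss` but does not show how it enters training. One common pattern adds it to the next-token cross-entropy with a small coefficient. The sketch below uses plain tensors so it runs without pylmt; the helper name `lm_loss_with_aux`, the shift-by-one target layout, and the `aux_coef=0.01` default are assumptions, not pylmt's Trainer API.

```python
import torch
import torch.nn.functional as F


def lm_loss_with_aux(logits, tokens, aux_loss, aux_coef=0.01):
    """Next-token cross-entropy plus a scaled load-balancing term (hypothetical helper).

    logits: [batch, seq, vocab] model output; tokens: [batch, seq] input IDs.
    Position t predicts token t + 1, so the last logit and the first token are dropped.
    """
    pred = logits[:, :-1, :].reshape(-1, logits.size(-1))
    target = tokens[:, 1:].reshape(-1)
    ce = F.cross_entropy(pred, target)
    # aux_coef is an assumed value; it trades off LM quality against expert balance.
    return ce + aux_coef * aux_loss
```

With the Mixtral example above, one would presumably call `lm_loss_with_aux(logits, x, model.aux_loss)` and then `backward()` on the result.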

Project Structure

src/lmt/
  layers/
    attention/     # MHA, GQA, Sliding Window, MLA
    ffn/           # SwiGLU, MoE (Router + Experts)
    normalization/ # RMSNorm
    positional/    # RoPE
  models/
    gpt/           # GPT (original decoder-only transformer)
    llama/         # LLaMA (RMSNorm + RoPE + SwiGLU + GQA)
    mixtral/       # Mixtral (LLaMA + MoE + sliding window)
  training/        # Trainer, configs, dataloaders
  tokenizer/       # BPE, naive tokenizers

Development

uv run pytest tests/ -v         # Run tests (171 passing)
uv run ruff check src/ tests/   # Lint
uv run ruff format src/ tests/  # Format
uv run pyright src/             # Type check
uv run mkdocs serve             # Local docs server

License

Apache License 2.0. See LICENSE for details.

Download files

Source Distribution

pylmt-0.3.0.tar.gz (590.1 kB)

Built Distribution

pylmt-0.3.0-py3-none-any.whl (59.9 kB)

File details

Details for the file pylmt-0.3.0.tar.gz.

File metadata

  • Download URL: pylmt-0.3.0.tar.gz
  • Size: 590.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pylmt-0.3.0.tar.gz
Algorithm Hash digest
SHA256 2206619a7e01d602c47bb39b8759ddf7bc8f8e23522871edcc45971271a4387b
MD5 4523c7a5233546af1c7142d888efb045
BLAKE2b-256 515a09121258e3253c49a85c7697c1226f94fcb15b3245ed7cfdd4b069a0719f

File details

Details for the file pylmt-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: pylmt-0.3.0-py3-none-any.whl
  • Size: 59.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pylmt-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c0fc30d6e44d55609f49074dad2c5ffc0340953f79b25620663d87f0e531fefe
MD5 fa05b812a9140684a149bd4429614ba6
BLAKE2b-256 24b54d5ca4cc8c02f38d87d98ea5e3d2306fc15400874d8b828dbf4ffd951e59
