A small-scale transformer-based language model implemented from scratch in Python.
ScratchGPT
ScratchGPT is a Python project that implements a small-scale transformer-based language model from scratch. It provides functionality for training the model on custom datasets and generating text based on prompts.
Features
- Custom transformer architecture implementation
- Training on user-provided text data
- Text generation using the trained model
- Flexible tokenization using TikToken
- Command-line interfaces for training and inference
Roadmap
- Switch to uv
- Make it easy to modify with a config file
- Extract the loss calculation from the model
- Rename main to train
- Create or check tokenizer interface
- Create an easy-to-use interface
- Make it into a package
- Apply SOTA optimizations
Requirements
- Python 3.12+
- uv for dependency management
Installation
- Clone the repository:
  git clone https://github.com/LabStrangeLoop/scratchgpt.git
  cd scratchgpt
- Install dependencies using uv:
  uv sync --all-groups
Usage
Training
To train the model on your custom dataset:
uv run train -t <path_to_training_data> -e <experiment_folder>
- -t, --train_source: Path to the training data file or folder
- -e, --experiment: Path to the folder where experiment checkpoints will be saved
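The two training options can be pictured as a small argparse parser. This is only a sketch of the documented flags, not the project's actual entry point: the function name `build_train_parser`, the use of `argparse`, and the choice to make both flags required are assumptions made for illustration.

```python
import argparse
from pathlib import Path


def build_train_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented `train` flags."""
    parser = argparse.ArgumentParser(
        prog="train",
        description="Train the model on a custom dataset.",
    )
    # Assumption: both flags are mandatory; the README does not say so.
    parser.add_argument(
        "-t",
        "--train_source",
        type=Path,
        required=True,
        help="Path to the training data file or folder",
    )
    parser.add_argument(
        "-e",
        "--experiment",
        type=Path,
        required=True,
        help="Folder where experiment checkpoints will be saved",
    )
    return parser


if __name__ == "__main__":
    args = build_train_parser().parse_args()
    print(f"training on {args.train_source}, saving to {args.experiment}")
```

Parsing into `Path` objects keeps later file handling uniform whether the training source is a single file or a folder.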
Inference
To generate text using a trained model:
uv run infer -e <experiment_folder> [-d <device>] [-m <max_tokens>]
- -e, --experiment: Path to the folder containing the trained model
- -d, --device: Device to run the model on (default: "cuda")
- -m, --max_tokens: Maximum number of tokens to generate (default: 512)
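To make the role of max_tokens concrete, here is a toy autoregressive loop that stops once the token budget is spent. It uses a bigram lookup table as a stand-in for the model; it is not ScratchGPT's transformer or its inference code, and the names `build_bigram_table` and `generate` are hypothetical.

```python
import random
from collections import defaultdict


def build_bigram_table(tokens: list[str]) -> dict[str, list[str]]:
    """Map each token to the list of tokens observed right after it."""
    table: dict[str, list[str]] = defaultdict(list)
    for current, following in zip(tokens, tokens[1:]):
        table[current].append(following)
    return table


def generate(
    table: dict[str, list[str]],
    prompt: list[str],
    max_tokens: int = 512,
    seed: int = 0,
) -> list[str]:
    """Extend the prompt one token at a time, adding at most max_tokens."""
    if not prompt:
        raise ValueError("prompt must contain at least one token")
    rng = random.Random(seed)
    output = list(prompt)
    for _ in range(max_tokens):
        candidates = table.get(output[-1])
        if not candidates:
            # No known continuation for the last token: stop early.
            break
        output.append(rng.choice(candidates))
    return output


if __name__ == "__main__":
    corpus = "the cat sat on the mat the cat".split()
    table = build_bigram_table(corpus)
    print(" ".join(generate(table, ["the"], max_tokens=5)))
```

The budget counts newly generated tokens only, so the returned sequence can be up to `len(prompt) + max_tokens` long. Whether the real infer command counts tokens the same way is not stated here.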
Tokenization
To explore the TikToken tokenizer:
uv run tiktoken
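For readers new to tokenizers, the basic contract is an encode step from text to integer ids and a decode step back to text. The sketch below shows that round trip with a character-level stand-in. It is not TikToken (which uses byte-pair encoding) and not the project's tokenizer interface; the class name `CharTokenizer` is hypothetical.

```python
class CharTokenizer:
    """Hypothetical character-level tokenizer, for illustration only."""

    def __init__(self, text: str) -> None:
        # The vocabulary is every distinct character seen, in sorted order.
        self.vocab: list[str] = sorted(set(text))
        self._char_to_id: dict[str, int] = {
            ch: idx for idx, ch in enumerate(self.vocab)
        }

    def encode(self, text: str) -> list[int]:
        """Convert text to token ids; raises KeyError on unseen characters."""
        return [self._char_to_id[ch] for ch in text]

    def decode(self, ids: list[int]) -> str:
        """Convert token ids back to text."""
        return "".join(self.vocab[idx] for idx in ids)


if __name__ == "__main__":
    tokenizer = CharTokenizer("hello world")
    ids = tokenizer.encode("hello")
    print(ids, "->", tokenizer.decode(ids))
```

A subword tokenizer such as TikToken follows the same encode/decode shape but maps multi-character chunks to ids, which yields shorter sequences for the same text.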
Project Structure
- scratchgpt/train.py: Main training script
- scratchgpt/infer.py: Inference script for text generation
- scratchgpt/model_io.py: Utilities for saving and loading models
- scratchgpt/tokenizer/: Tokenizer implementations
Development
This project uses various development tools:
- mypy for static type checking
- ruff for formatting and standard adherence
- pytest for testing
Run the following commands to ensure code quality:
uv run ruff check --fix .
uv run mypy .
uv run pytest
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
License
Authors
- Aleksandr Yeganov
- Dario Cazzani
Download files
File details
Details for the file scratchgpt-0.3.0.tar.gz.
File metadata
- Download URL: scratchgpt-0.3.0.tar.gz
- Upload date:
- Size: 426.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5cc84af874e6d6cfa5d2a582749116235718de50cf5d32831a97abeded5b8f5d |
| MD5 | 3552cce4a36f94735aae6cb1b3f202a8 |
| BLAKE2b-256 | d14885df80501ad8a81d2d9d8e76480fed279076e1076bc7ef98d845adbb91db |
File details
Details for the file scratchgpt-0.3.0-py3-none-any.whl.
File metadata
- Download URL: scratchgpt-0.3.0-py3-none-any.whl
- Upload date:
- Size: 20.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f7431d679e16ae3bd46c8144ae1c662cdbde3b4f0738085666eaf3c01bd8be69 |
| MD5 | 80b0e3bc0cc8b22a7875e7ef4f0ce84b |
| BLAKE2b-256 | 25960ca6814047a2ea08d6f7fed59c20c09ad3daab96148d8fa61806ee368a96 |