A small-scale transformer-based language model implemented from scratch in Python.

ScratchGPT

ScratchGPT is a Python project that implements a small-scale transformer-based language model from scratch. It provides functionality for training the model on custom datasets and generating text based on prompts.

Features

  • Custom transformer architecture implementation
  • Training on user-provided text data
  • Text generation using the trained model
  • Flexible tokenization using TikToken
  • Command-line interfaces for training and inference

Roadmap

  • Switch to uv
  • Make it easy to modify with a config file
  • Extract the loss calculation from the model
  • Rename main to train
  • Create or check tokenizer interface
  • Create an easy to use interface
  • Make it into a package
  • Apply SOTA optimizations

Requirements

  • Python 3.12+
  • uv for dependency management

Installation

  1. Clone the repository:

    git clone https://github.com/LabStrangeLoop/scratchgpt.git
    cd scratchgpt
    
  2. Install dependencies using uv:

    uv sync --all-groups
    

Usage

Training

To train the model on your custom dataset:

    uv run train -t <path_to_training_data> -e <experiment_folder>

  • -t, --train_source: Path to the training data file or folder
  • -e, --experiment: Path to the folder where experiment checkpoints will be saved
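
To illustrate how the two documented flags fit together, here is a minimal argparse sketch. It is not ScratchGPT's actual parser: the function name build_train_parser, the use of argparse, the Path type, and marking both flags as required are assumptions made for this example. Only the flag names and their meanings come from the option list above.

```python
import argparse
from pathlib import Path


def build_train_parser() -> argparse.ArgumentParser:
    """Hypothetical parser that mirrors the documented training flags."""
    parser = argparse.ArgumentParser(prog="train")
    parser.add_argument(
        "-t",
        "--train_source",
        type=Path,
        required=True,  # assumption: the README does not say whether this is mandatory
        help="Path to the training data file or folder",
    )
    parser.add_argument(
        "-e",
        "--experiment",
        type=Path,
        required=True,  # assumption, as above
        help="Path to the folder where experiment checkpoints will be saved",
    )
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
# Both paths are placeholders.
args = build_train_parser().parse_args(["-t", "data/corpus.txt", "-e", "experiments/demo"])
print(args.train_source, args.experiment)
```

argparse derives the attribute names from the long options, so the values are available as args.train_source and args.experiment.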

Inference

To generate text using a trained model:

    uv run infer -e <experiment_folder> [-d <device>] [-m <max_tokens>]

  • -e, --experiment: Path to the folder containing the trained model
  • -d, --device: Device to run the model on (default: "cuda")
  • -m, --max_tokens: Maximum number of tokens to generate (default: 512)

Tokenization

To explore the TikToken tokenizer:

    uv run tiktoken

Project Structure

  • scratchgpt/train.py: Main training script
  • scratchgpt/infer.py: Inference script for text generation
  • scratchgpt/model_io.py: Utilities for saving and loading models
  • scratchgpt/tokenizer/: Tokenizer implementations

Development

This project uses various development tools:

  • mypy for static type checking
  • ruff for linting and formatting
  • pytest for testing

Run the following commands to ensure code quality:

    uv run ruff check --fix .
    uv run mypy .
    uv run pytest

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

License

MIT License

Authors

  • Aleksandr Yeganov
  • Dario Cazzani

Download files

Source Distribution

scratchgpt-0.3.0.tar.gz (426.9 kB view details)

Uploaded Source

Built Distribution

scratchgpt-0.3.0-py3-none-any.whl (20.0 kB view details)

Uploaded Python 3

File details

Details for the file scratchgpt-0.3.0.tar.gz.

File metadata

  • Download URL: scratchgpt-0.3.0.tar.gz
  • Upload date:
  • Size: 426.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.14

File hashes

Hashes for scratchgpt-0.3.0.tar.gz
Algorithm Hash digest
SHA256 5cc84af874e6d6cfa5d2a582749116235718de50cf5d32831a97abeded5b8f5d
MD5 3552cce4a36f94735aae6cb1b3f202a8
BLAKE2b-256 d14885df80501ad8a81d2d9d8e76480fed279076e1076bc7ef98d845adbb91db

File details

Details for the file scratchgpt-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: scratchgpt-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 20.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.14

File hashes

Hashes for scratchgpt-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f7431d679e16ae3bd46c8144ae1c662cdbde3b4f0738085666eaf3c01bd8be69
MD5 80b0e3bc0cc8b22a7875e7ef4f0ce84b
BLAKE2b-256 25960ca6814047a2ea08d6f7fed59c20c09ad3daab96148d8fa61806ee368a96
