KAN-GPT
The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling
Install it from PyPI
pip install kan_gpt
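As a quick smoke test after installing, the import used throughout the Usage section should succeed. This is a minimal sketch; it assumes only the GPT class and the get_default_config call shown below.
from kan_gpt.model import GPT

# Print the default model configuration to confirm the package imports cleanly.
config = GPT.get_default_config()
print(config)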
Citation
If you find our work useful, cite us!
@misc{GANESH2024KANGPT,
    author = {Aditya Nalgunda Ganesh},
    title = {KAN-GPT: The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling},
    year = {2024},
    month = {May},
    note = {Release 1.0.0, 9th May 2024},
    url = {https://github.com/AdityaNG/kan-gpt/}
}
Usage
Refer to the KAN_GPT.ipynb and kan_gpt/prompt.py for usage examples. The following is an outline of how to use the model:
import torch
from kan_gpt.model import GPT
from transformers import GPT2Tokenizer
model_config = GPT.get_default_config()
model_config.model_type = "gpt2"
model_config.vocab_size = 50257
model_config.block_size = 1024
model = GPT(model_config)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
prompt = "Bangalore is often described as the "
prompt_encoded = tokenizer.encode(
    text=prompt, add_special_tokens=False
)
x = torch.tensor(prompt_encoded).unsqueeze(0)
model.eval()
y = model.generate(x, 50) # sample 50 tokens
result = tokenizer.decode(y[0])
print(result)
# Bangalore is often described as the Silicon Valley of India.
# The city has witnessed rapid growth in the past two decades.....
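To reuse a model later, for example with the prompt script described below, the weights can be persisted with standard PyTorch utilities. This is a minimal sketch using a plain state_dict; the filename is hypothetical and the checkpoint written by the training script may store additional state.
# Sketch: save and restore the model weights with standard PyTorch calls.
# NOTE: "kan_gpt_checkpoint.pth" is a hypothetical filename; the checkpoint
# produced by kan_gpt.train may use a different layout.
torch.save(model.state_dict(), "kan_gpt_checkpoint.pth")

model = GPT(model_config)
model.load_state_dict(torch.load("kan_gpt_checkpoint.pth"))
model.eval()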
Setup for Development
# Download Repo
git clone https://github.com/AdityaNG/kan-gpt
cd kan-gpt
git pull
# Download Dataset
python3 -m kan_gpt.download_dataset --dataset tinyshakespeare
python3 -m kan_gpt.download_dataset --dataset mnist
python3 -m kan_gpt.download_dataset --dataset webtext
# Install dependencies for development
pip install -r requirements.txt
pip install -e .
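After downloading, a quick listing of the files on disk helps confirm the datasets landed where expected. The datasets/ directory below is an assumption about where the download script writes; adjust the path to match your checkout.
# Sketch: list downloaded dataset files and their sizes.
# ASSUMPTION: the download script writes into a local "datasets/" directory.
from pathlib import Path

for path in sorted(Path("datasets").rglob("*")):
    if path.is_file():
        print(f"{path} ({path.stat().st_size} bytes)")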
Train
Use the following dummy scripts to make sure everything is working as expected:
WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu --max_iters 200
WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu --max_iters 200
Then make use of the training script
python -m kan_gpt.train
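The same dummy check can be reproduced in a few lines of Python. The sketch below assumes the GPT module follows the minGPT convention of returning (logits, loss) when targets are passed; check kan_gpt/model.py if your version differs.
# Sketch: one dummy forward-backward pass on random token ids.
# ASSUMPTION: model(idx, targets) returns (logits, loss), as in minGPT.
import torch
from kan_gpt.model import GPT

config = GPT.get_default_config()
config.model_type = "gpt2"
config.vocab_size = 50257
config.block_size = 1024
model = GPT(config)

x = torch.randint(0, config.vocab_size, (1, 64))  # dummy input tokens
y = torch.randint(0, config.vocab_size, (1, 64))  # dummy next-token targets

logits, loss = model(x, y)
loss.backward()
print("dummy loss:", loss.item())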
Prompt
You can prompt the model to produce text as follows
python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model_path (checkpoint)
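Programmatically, prompting is the same generate call shown in the Usage section, after restoring weights from a checkpoint. The sketch below assumes the checkpoint is a plain state_dict saved with torch.save; the file written by the training script may wrap additional state.
# Sketch: restore weights and generate from a prompt in Python.
# ASSUMPTION: "kan_gpt_checkpoint.pth" is a plain state_dict checkpoint.
import torch
from transformers import GPT2Tokenizer
from kan_gpt.model import GPT

config = GPT.get_default_config()
config.model_type = "gpt2"
config.vocab_size = 50257
config.block_size = 1024
model = GPT(config)
model.load_state_dict(torch.load("kan_gpt_checkpoint.pth"))
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
x = torch.tensor(
    tokenizer.encode("Bangalore is often described as the ", add_special_tokens=False)
).unsqueeze(0)
print(tokenizer.decode(model.generate(x, 50)[0]))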
Results
We train and compare KAN-GPT with an equivalent MLP-GPT model on the Tiny Shakespeare dataset. We observe that KAN-GPT performs slightly better than MLP-GPT; further experiments are planned to investigate this more deeply. The results are shown below:
Metrics: (training and validation metric plots comparing KAN-GPT and MLP-GPT; see the repository README for the images.)
TODOs
- Integrate minGPT and pykan
- Dataset downloading script for WebText
- PyTorch Dataset parser for WebText
- PyTorch Dataset parser for tinyshakespeare
- Mini training POC for KAN-GPT
- Integrate KAN training logic from KAN.train_kan
- Train a dummy batch w/o any memory issues
- Mini training POC for MLP-GPT
- Train MLP-GPT on the webtext dataset as a baseline
- Train KAN-GPT on the webtext dataset as a baseline
- Metrics comparing KAN-GPT and MLP-GPT
- Auto Save checkpoints
- Auto Save checkpoints to W&B
- Auto Download model weights from git / huggingface
- W&B hyperparam sweep script
- Script to load checkpoint in interactive mode
- Reduce requirements.txt constraints
- Define pydantic model for training and sweep args
- Pruning the package, get rid of unused code
- Training script to PyTorch Lightning
- Documentation: mkdocs gh-deploy
- Integrate with efficient-kan
- Test Cases
- KAN: Forward-Backward test
- GPT: Forward-Backward test
- KAN_GPT: Forward-Backward test
- EFFICIENT_KAN: Forward-Backward test
Development
Read the CONTRIBUTING.md file.
References
- minGPT: https://github.com/karpathy/minGPT
- pykan: https://github.com/KindXiaoming/pykan
- efficient-kan: https://github.com/Blealtan/efficient-kan