Awesome kan_gpt created by AdityaNG

KAN-GPT

The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling

Install it from PyPI

pip install kan_gpt

Citation

If you find our work useful, please cite us!

@misc{GANESH2024KANGPT,
  author       = {Aditya Nalgunda Ganesh},
  title        = {KAN-GPT: The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling},
  year         = {2024},
  month        = {May},
  note         = {Release 1.0.0, 9th May 2024},
  url          = {https://github.com/AdityaNG/kan-gpt/}
}

Usage

Refer to KAN_GPT.ipynb and kan_gpt/prompt.py for usage examples. The following outlines how to use the model:

import torch
from kan_gpt.model import GPT
from transformers import GPT2Tokenizer

model_config = GPT.get_default_config()
model_config.model_type = "gpt2"
model_config.vocab_size = 50257
model_config.block_size = 1024
model = GPT(model_config)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

prompt = "Bangalore is often described as the "

prompt_encoded = tokenizer.encode(
  text=prompt, add_special_tokens=False
)

x = torch.tensor(prompt_encoded).unsqueeze(0)

model.eval()
y = model.generate(x, 50)  # sample 50 tokens

result = tokenizer.decode(y[0])

print(result)

# Example output, assuming trained weights have been loaded into the model:
# Bangalore is often described as the Silicon Valley of India.
# The city has witnessed rapid growth in the past two decades...

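To make the "sample 50 tokens" step more concrete, here is a minimal sketch of an autoregressive sampling loop in plain Python. It assumes a minGPT-style loop in which the context is cropped to the last `block_size` tokens before each prediction; the actual `generate` method in kan_gpt may differ. The names `generate_tokens` and `toy_model` are hypothetical and exist only for illustration.

```python
import random
from typing import Callable, List, Sequence


def generate_tokens(
    next_token_probs: Callable[[Sequence[int]], List[float]],
    prompt: Sequence[int],
    max_new_tokens: int,
    block_size: int,
    seed: int = 0,
) -> List[int]:
    """Toy autoregressive sampler (illustrative only, not kan_gpt's code)."""
    rng = random.Random(seed)
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        # Only the most recent block_size tokens fit in the context window.
        context = tokens[-block_size:]
        probs = next_token_probs(context)
        # Sample the next token id in proportion to its probability.
        next_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        tokens.append(next_id)
    return tokens


def toy_model(context: Sequence[int]) -> List[float]:
    """Hypothetical stand-in for a trained model over a 4-token vocabulary.

    It always predicts (last token + 1) mod 4 with probability 1.
    """
    probs = [0.0] * 4
    probs[(context[-1] + 1) % 4] = 1.0
    return probs


out = generate_tokens(toy_model, prompt=[0], max_new_tokens=5, block_size=3)
print(out)  # [0, 1, 2, 3, 0, 1]
```

The returned list contains the prompt followed by the newly sampled tokens, which is why the usage outline decodes the whole of `y[0]` and the printed text begins with the prompt.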
Setup for Development

# Download Repo
git clone https://github.com/AdityaNG/kan-gpt
cd kan-gpt
git pull

# Download Dataset
python3 -m kan_gpt.download_dataset --dataset tinyshakespeare
python3 -m kan_gpt.download_dataset --dataset mnist
python3 -m kan_gpt.download_dataset --dataset webtext

# Install dependencies for development
pip install -r requirements.txt
pip install -e .

Train

Run the following dummy training commands to check that everything works as expected:

WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu --max_iters 200
WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu --max_iters 200
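If you prefer to drive these smoke tests from Python (for example in a CI script), the sketch below builds the same command line and environment. The helper name `smoke_test_command` is hypothetical; the flags are the ones shown above. Note that the environment variable CUDA reads to hide GPUs is spelled `CUDA_VISIBLE_DEVICES`. The launch line is left commented out because it assumes kan_gpt and its dependencies are installed.

```python
import os
import subprocess  # only needed if you uncomment the launch line below
import sys
from typing import Dict, List, Tuple


def smoke_test_command(
    architecture: str, max_iters: int = 200
) -> Tuple[List[str], Dict[str, str]]:
    """Build argv and environment for one CPU-only dummy training run."""
    env = dict(os.environ)
    env["WANDB_MODE"] = "offline"     # keep Weights & Biases from syncing
    env["CUDA_VISIBLE_DEVICES"] = ""  # hide all GPUs from CUDA
    argv = [
        sys.executable, "-m", "kan_gpt.train",
        "--architecture", architecture,
        "--batch_size", "1",
        "--dummy_dataset",
        "--device", "cpu",
        "--max_iters", str(max_iters),
    ]
    return argv, env


for arch in ("MLP", "KAN"):
    argv, env = smoke_test_command(arch)
    print(" ".join(argv[1:]))
    # To actually launch the run:
    # subprocess.run(argv, env=env, check=True)
```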

Then run the training script:

python -m kan_gpt.train

Prompt

You can prompt the model to produce text as follows:

python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model_path (checkpoint)

Results

We train and compare KAN-GPT with an equivalent MLP-GPT model on the Tiny Shakespeare dataset. We observe that KAN-GPT performs slightly better than MLP-GPT, and we plan further experiments to examine this in more depth. The results are shown below:

Metrics (plots): results_loss, results_cross_entropy, results_perplexity
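For readers comparing the cross-entropy and perplexity metrics, the two are related by the standard definition: perplexity is the exponential of the mean per-token cross-entropy (in nats). The sketch below illustrates that relationship with made-up numbers, not project results; whether the plots above were computed in exactly this way is an assumption, and `perplexity` here is a hypothetical helper.

```python
import math
from typing import Sequence


def perplexity(token_losses: Sequence[float]) -> float:
    """Perplexity = exp(mean per-token cross-entropy, measured in nats)."""
    if not token_losses:
        raise ValueError("need at least one token loss")
    return math.exp(sum(token_losses) / len(token_losses))


# A model that assigns probability 1/4 to every correct token has
# cross-entropy ln(4) per token and therefore perplexity 4.
uniform_losses = [math.log(4)] * 10
print(round(perplexity(uniform_losses), 6))  # 4.0
```

Because the mapping is monotonic, a model with lower cross-entropy on the same data always has lower perplexity as well.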

TODOs

  • Integrate minGPT and pykan
  • Dataset downloading script for WebText
  • PyTorch Dataset parser for WebText
  • PyTorch Dataset parser for tinyshakespeare
  • Mini training POC for KAN-GPT
    • Integrate KAN training logic from KAN.train_kan
    • Train a dummy batch w/o any memory issues
  • Mini training POC for MLP-GPT
  • Train MLP-GPT on the webtext dataset as a baseline
  • Train KAN-GPT on the webtext dataset as a baseline
  • Metrics comparing KAN-GPT and MLP-GPT
  • Auto Save checkpoints
  • Auto Save checkpoints to W&B
  • Auto Download model weights from git / huggingface
  • W&B hyperparam sweep script
  • Script to load checkpoint in interactive mode
  • Reduce requirements.txt constraints
  • Define pydantic model for training and sweep args
  • Pruning the package, get rid of unused code
  • Training script to PyTorch Lightning
  • Documentation: mkdocs gh-deploy
  • Integrate with efficient-kan
  • Test Cases
    • KAN: Forward-Backward test
    • GPT: Forward-Backward test
    • KAN_GPT: Forward-Backward test
    • EFFICIENT_KAN: Forward-Backward test

Development

Read the CONTRIBUTING.md file.
