


KAN-GPT


The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling
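To make the KAN idea concrete: a standard MLP layer multiplies each input by a scalar weight, sums the products, and applies a fixed nonlinearity at each output node. A KAN layer instead places a learnable one-dimensional function on every input-output edge and simply sums the edge outputs. The dependency-free sketch below contrasts the two. It assumes each edge function is a piecewise-linear interpolant over a fixed grid (a degree-1 B-spline); pykan and efficient-kan use higher-order B-splines plus a base activation, so treat this as a conceptual illustration, not kan_gpt's actual layer. The names `mlp_layer`, `kan_layer`, and `piecewise_linear` are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def mlp_layer(x: Vector, weights: List[Vector], bias: Vector) -> Vector:
    """MLP layer: one scalar weight per edge, fixed nonlinearity (tanh) per output node."""
    return [
        math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
        for row, b in zip(weights, bias)
    ]


def piecewise_linear(t: float, grid: Vector, values: Vector) -> float:
    """Illustrative learnable 1-D function: linear interpolation of `values` at `grid` knots.

    Inputs outside the grid are clamped to the end knot values.
    """
    if t <= grid[0]:
        return values[0]
    if t >= grid[-1]:
        return values[-1]
    for i in range(len(grid) - 1):
        if grid[i] <= t <= grid[i + 1]:
            frac = (t - grid[i]) / (grid[i + 1] - grid[i])
            return values[i] + frac * (values[i + 1] - values[i])
    raise ValueError("grid must be sorted in ascending order")


def kan_layer(x: Vector, grid: Vector, coeffs: List[List[Vector]]) -> Vector:
    """KAN layer sketch: output j = sum_i phi_ji(x_i), each phi_ji a learnable 1-D function.

    coeffs[j][i] holds the knot values of the edge function from input i to output j;
    these knot values play the role that scalar weights play in an MLP layer.
    """
    return [
        sum(piecewise_linear(xi, grid, coeffs[j][i]) for i, xi in enumerate(x))
        for j in range(len(coeffs))
    ]


if __name__ == "__main__":
    grid = [-1.0, 0.0, 1.0]
    # Two inputs, one output: edge 0 approximates |x|, edge 1 approximates x.
    coeffs = [[[1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]]]
    print(kan_layer([-0.5, 0.25], grid, coeffs))  # [0.75]
    print(mlp_layer([0.0, 0.0], [[1.0, 2.0]], [0.0]))  # [0.0]
```

In a real KAN the knot values (spline coefficients) are trained by gradient descent, just as MLP weights are; the sketch only shows the forward pass.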

Install it from PyPI

pip install kan_gpt

Citation

If you find our work useful, please cite us:

@misc{GANESH2024KANGPT,
  author       = {Aditya Nalgunda Ganesh},
  title        = {KAN-GPT: The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling},
  year         = {2024},
  month        = {May},
  note         = {Release 1.0.0, 9th May 2024},
  url          = {https://github.com/AdityaNG/kan-gpt/}
}

Usage

Refer to the KAN_GPT.ipynb notebook and kan_gpt/prompt.py for usage examples. The following outlines how to use the model:

import torch
from kan_gpt.model import GPT
from transformers import GPT2Tokenizer

model_config = GPT.get_default_config()
model_config.model_type = "gpt2"
model_config.vocab_size = 50257
model_config.block_size = 1024
model = GPT(model_config)  # randomly initialized; load trained weights for meaningful output

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

prompt = "Bangalore is often described as the "

prompt_encoded = tokenizer.encode(
  text=prompt, add_special_tokens=False
)

# Add a batch dimension: shape (1, prompt_length)
x = torch.tensor(prompt_encoded).unsqueeze(0)

model.eval()
y = model.generate(x, 50)  # sample 50 tokens

# generate() returns a batch of sequences; decode the first (and only) one
result = tokenizer.decode(y[0])

print(result)

# Example output from a trained model:
# Bangalore is often described as the Silicon Valley of India.
# The city has witnessed rapid growth in the past two decades.....
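A call like `model.generate(x, 50)` extends the prompt autoregressively: at each step the model scores every vocabulary entry as the possible next token, one token is chosen and appended, and the loop repeats. The dependency-free sketch below shows that loop with greedy (argmax) decoding and a crop of the context to `block_size`, as in minGPT-style models. The names `generate_greedy` and `toy_logits` are hypothetical; `toy_logits` stands in for a GPT forward pass, and kan_gpt's own `generate` may sample rather than take the argmax.

```python
from typing import Callable, List

Tokens = List[int]


def generate_greedy(
    logits_fn: Callable[[Tokens], List[float]],
    prompt: Tokens,
    max_new_tokens: int,
    block_size: int,
) -> Tokens:
    """Append max_new_tokens tokens to prompt, one at a time, by argmax decoding."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        context = tokens[-block_size:]  # keep at most block_size most recent tokens
        logits = logits_fn(context)  # one score per vocabulary entry
        next_token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        tokens.append(next_token)
    return tokens


def toy_logits(context: Tokens, vocab_size: int = 5) -> List[float]:
    """Hypothetical stand-in for a GPT forward pass: favors (last token + 1) mod vocab_size."""
    favored = (context[-1] + 1) % vocab_size
    return [1.0 if t == favored else 0.0 for t in range(vocab_size)]


if __name__ == "__main__":
    print(generate_greedy(toy_logits, [0], max_new_tokens=6, block_size=4))
    # [0, 1, 2, 3, 4, 0, 1]
```

The returned list holds the prompt followed by the new tokens, which is why the decoded text in the usage example starts with the prompt itself.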

Setup for Development

# Download Repo
git clone https://github.com/AdityaNG/kan-gpt
cd kan-gpt
git pull

# Download Dataset
./scripts/download_webtext.sh
./scripts/download_tinyshakespeare.sh

# Install dependencies for development
pip install -r requirements.txt
pip install -e .

Train

Run the following short training jobs on a dummy dataset (one per architecture, MLP and KAN) to check that everything works as expected:

WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu --max_iters 200
WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu --max_iters 200

Then run the full training script:

python -m kan_gpt.train
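Training a GPT optimizes next-token prediction. As an illustration only (we have not confirmed here that kan_gpt's dataset classes work exactly this way), the sketch below shows how a token stream is typically cut into fixed-length training examples: each input window of `block_size` tokens is paired with the same window shifted one position to the right, so the target at position t is the token that follows position t. The function name `make_examples` and its `stride` parameter are hypothetical.

```python
from typing import List, Tuple

Tokens = List[int]


def make_examples(stream: Tokens, block_size: int, stride: int) -> List[Tuple[Tokens, Tokens]]:
    """Cut a token stream into (input, target) pairs for next-token prediction.

    Each pair covers block_size + 1 consecutive tokens: the input is the first
    block_size of them and the target is the same span shifted right by one.
    Trailing tokens that cannot fill a whole window are dropped.
    """
    examples = []
    for start in range(0, len(stream) - block_size, stride):
        chunk = stream[start : start + block_size + 1]
        examples.append((chunk[:-1], chunk[1:]))
    return examples


if __name__ == "__main__":
    for x, y in make_examples([10, 11, 12, 13, 14, 15], block_size=3, stride=2):
        print(x, "->", y)
    # [10, 11, 12] -> [11, 12, 13]
    # [12, 13, 14] -> [13, 14, 15]
```

A stride equal to `block_size` gives non-overlapping windows; a smaller stride yields more (overlapping) examples from the same text.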

Prompt

You can prompt a trained checkpoint to generate text as follows:

python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model_path (checkpoint)

Results

We train KAN-GPT and an equivalent MLP-GPT model on the Tiny Shakespeare dataset and compare them. We observe that KAN-GPT performs slightly better than MLP-GPT on the metrics below, and we are planning further experiments to investigate this in more depth. The results are shown below:

Metrics: plots of loss, cross-entropy, and perplexity
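The cross-entropy and perplexity metrics are directly related. For a language model, cross-entropy is the average negative log-probability the model assigns to each true next token, and perplexity is its exponential, so a lower cross-entropy always means a lower perplexity. The sketch below computes both from a list of per-token probabilities; it assumes natural-log cross-entropy (as PyTorch's `cross_entropy` uses), the function names are hypothetical, and the input values are made up for illustration, not taken from these experiments.

```python
import math
from typing import List


def cross_entropy(true_token_probs: List[float]) -> float:
    """Mean negative natural-log probability of the correct next token (nats per token)."""
    return -sum(math.log(p) for p in true_token_probs) / len(true_token_probs)


def perplexity(true_token_probs: List[float]) -> float:
    """Perplexity = exp(cross-entropy)."""
    return math.exp(cross_entropy(true_token_probs))


if __name__ == "__main__":
    probs = [0.25, 0.25, 0.25, 0.25]  # illustrative values only
    print(round(cross_entropy(probs), 4))  # 1.3863
    print(round(perplexity(probs), 4))  # 4.0
```

A model that always gives the correct token probability 1/k has perplexity k, which is why perplexity is often read as the effective number of choices the model is hesitating between.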

TODOs

  • Integrate minGPT and pykan
  • Dataset downloading script for WebText
  • PyTorch Dataset parser for WebText
  • PyTorch Dataset parser for tinyshakespeare
  • Mini training POC for KAN-GPT
    • Integrate KAN training logic from KAN.train_kan
    • Train a dummy batch w/o any memory issues
  • Mini training POC for MLP-GPT
  • Train MLP-GPT on the webtext dataset as a baseline
  • Train KAN-GPT on the webtext dataset as a baseline
  • Metrics comparing KAN-GPT and MLP-GPT
  • Auto Save checkpoints
  • Auto Save checkpoints to W&B
  • Auto Download model weights from git / huggingface
  • W&B hyperparam sweep script
  • Script to load checkpoint in interactive mode
  • Reduce requirements.txt constraints
  • Define pydantic model for training and sweep args
  • Pruning the package, get rid of unused code
  • Training script to PyTorch Lightning
  • Documentation: mkdocs gh-deploy
  • Integrate with efficient-kan
  • Test Cases
    • KAN: Forward-Backward test
    • GPT: Forward-Backward test
    • KAN_GPT: Forward-Backward test
    • EFFICIENT_KAN: Forward-Backward test

Development

Read the CONTRIBUTING.md file.

