KAN-GPT

Awesome KAN-GPT created by AdityaNG

Install it from PyPI

pip install kan_gpt

Usage

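import torch

from kan_gpt.model import GPT

# Configure a small GPT-2-style model
model_config = GPT.get_default_config()
model_config.model_type = "gpt2"
model_config.vocab_size = 5
model_config.block_size = 10
model = GPT(model_config)

# Dummy input and target token ids: a batch of 1 with block_size tokens
x = torch.zeros((1, 10), dtype=torch.long)
y = torch.zeros((1, 10), dtype=torch.long)

# Uncomment to run on a GPU
# x = x.cuda()
# y = y.cuda()
# model = model.cuda()

# The forward pass returns the logits and the loss against the targets y
logits, loss = model(x, y)

print(logits.shape)

Run the training module from the command line:

$ python -m kan_gpt.train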

Setup

# Download Repo
%cd /content
!git clone https://github.com/AdityaNG/kan-gpt
%cd kan-gpt
!git pull

# Download Dataset
!./scripts/download_webtext.sh

# Install dependencies for development
!pip install -r requirements.txt
!pip install -e .

Train

Run a dummy training script to check that everything works as expected:

CUDA_VISIBLE_DEVICES="0" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset
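
The same dummy run can also be launched from Python, for example from a notebook cell or a CI job. The sketch below is a minimal illustration, not part of kan_gpt: the helper names `build_train_command`, `build_train_env`, and `run_dummy_training` are hypothetical, and only the flags shown in the command above are used. The environment variable that CUDA reads to select GPUs is `CUDA_VISIBLE_DEVICES`.

```python
import os
import subprocess
import sys


def build_train_command(architecture="MLP", batch_size=1, dummy_dataset=True):
    """Build the argv list for `python -m kan_gpt.train`.

    Hypothetical helper: it only assembles the flags shown in the
    README command and does not validate them.
    """
    cmd = [
        sys.executable, "-m", "kan_gpt.train",
        "--architecture", architecture,
        "--batch_size", str(batch_size),
    ]
    if dummy_dataset:
        cmd.append("--dummy_dataset")
    return cmd


def build_train_env(gpu_ids="0"):
    """Copy the current environment and restrict CUDA to the given GPU ids."""
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = gpu_ids
    return env


def run_dummy_training():
    """Launch the dummy run; raises CalledProcessError on a non-zero exit."""
    subprocess.run(build_train_command(), env=build_train_env(), check=True)
```

Calling `run_dummy_training()` assumes kan_gpt is installed in the current interpreter. Using `sys.executable` keeps the subprocess in the same environment as the caller.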

TODOs

  • Integrate minGPT and pykan
  • Dataset downloading script for WebText
  • PyTorch Dataset parser for WebText
  • Mini training POC for KAN-GPT
    • Integrate KAN training logic from KAN.train_kan
  • Mini training POC for MLP-GPT
  • Train MLP-GPT on the webtext dataset as a baseline
  • Auto Save checkpoints
  • Auto Save checkpoints to W&B
  • Script to load checkpoint in interactive mode
  • Port the training script to PyTorch Lightning

Development

Read the CONTRIBUTING.md file.
