Convolutional KAN layer

Implementation of the Convolutional Kolmogorov-Arnold Network layer in PyTorch.

A drop-in replacement for the torch.nn.Conv2d layer that uses the Kolmogorov-Arnold Network (KAN) instead of the standard convolution.
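One way to picture "a KAN instead of the standard convolution" is as a sliding window where each flattened patch is passed to an arbitrary function rather than a fixed dot product. The pure-Python sketch below illustrates only that idea. The helper names `extract_patches` and `apply_per_patch` are hypothetical and are not part of convkan, and the package's actual internals may differ.

```python
def extract_patches(image, kernel_size, stride=1, padding=0):
    """Return a grid of flattened kernel_size x kernel_size patches (zero padding)."""
    h, w = len(image), len(image[0])
    # Zero-pad the image on all four sides.
    padded = [[0.0] * (w + 2 * padding) for _ in range(h + 2 * padding)]
    for r in range(h):
        for c in range(w):
            padded[r + padding][c + padding] = image[r][c]
    out_h = (h + 2 * padding - kernel_size) // stride + 1
    out_w = (w + 2 * padding - kernel_size) // stride + 1
    return [
        [
            [padded[i * stride + di][j * stride + dj]
             for di in range(kernel_size) for dj in range(kernel_size)]
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


def apply_per_patch(image, patch_fn, kernel_size, stride=1, padding=0):
    """Apply patch_fn to every patch; patch_fn stands in for a learned function."""
    patches = extract_patches(image, kernel_size, stride, padding)
    return [[patch_fn(p) for p in row] for row in patches]


# With patch_fn = sum, this reduces to an ordinary convolution with an all-ones kernel.
image = [[1.0] * 3 for _ in range(3)]
result = apply_per_patch(image, sum, kernel_size=3, padding=1)
print(result)
```

A standard convolution corresponds to `patch_fn` being a weighted sum. In a convolutional KAN, `patch_fn` would instead be a small KAN with learnable univariate functions.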

It currently supports grouped convolution, padding with different modes, dilation, and stride.
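If the layer really is a drop-in replacement for torch.nn.Conv2d, its output size should follow the same arithmetic as Conv2d. The sketch below applies the formula from the Conv2d documentation along one spatial dimension. That ConvKAN matches it exactly is an assumption drawn from the "drop-in replacement" claim, and `conv2d_output_size` is a hypothetical helper, not part of the package.

```python
def conv2d_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output size along one spatial dimension, per the torch.nn.Conv2d formula."""
    # A dilated kernel spans dilation * (kernel_size - 1) + 1 input positions.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# Layer settings from the Usage example: kernel_size=3, padding=1, strides 1, 2, 2.
h = 28  # MNIST images are 28x28
sizes = []
for stride in (1, 2, 2):
    h = conv2d_output_size(h, kernel_size=3, stride=stride, padding=1)
    sizes.append(h)
print(sizes)
```

Under that assumption, the MNIST model in the Usage section reduces the 28x28 input to 7x7 before the adaptive average pooling.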

The KAN implementation is taken from the https://github.com/Blealtan/efficient-kan/ repository.
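For readers unfamiliar with KANs, the sketch below shows a single KAN "edge": a learnable univariate function built from a base activation plus a weighted sum of B-spline basis functions. It is a pure-Python illustration of the general formulation, not code from efficient-kan or convkan. The grid size, spline order, grid range, and SiLU base activation used here are illustrative assumptions, and all function names are hypothetical.

```python
import math


def silu(x):
    """SiLU activation, x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def make_knots(grid_size=5, spline_order=3, lo=-1.0, hi=1.0):
    """Uniform knot vector extended by spline_order knots on each side."""
    h = (hi - lo) / grid_size
    return [lo + (j - spline_order) * h
            for j in range(grid_size + 2 * spline_order + 1)]


def bspline_basis(x, knots, order):
    """Evaluate all B-spline basis functions at x with the Cox-de Boor recursion."""
    # Order 0: indicator of each knot interval.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0
             for i in range(len(knots) - 1)]
    for k in range(1, order + 1):
        basis = [
            (x - knots[i]) / (knots[i + k] - knots[i]) * basis[i]
            + (knots[i + k + 1] - x) / (knots[i + k + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(basis) - 1)
        ]
    return basis


def kan_edge(x, coeffs, base_weight=1.0, spline_order=3, grid_size=5):
    """phi(x) = base_weight * silu(x) + sum_i coeffs[i] * B_i(x)."""
    basis = bspline_basis(x, make_knots(grid_size, spline_order), spline_order)
    return base_weight * silu(x) + sum(c * b for c, b in zip(coeffs, basis))


basis = bspline_basis(0.3, make_knots(), 3)
print(len(basis), sum(basis))
```

With `grid_size=5` and `spline_order=3` there are 8 basis functions per edge, and inside the grid range they sum to 1, so the spline coefficients act as local control values for the learned function.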

Installation

From PyPI:

pip install convkan

From source:

git clone git@github.com:StarostinV/convkan.git
cd convkan
pip install .

Usage

Training a simple model on MNIST (96% accuracy after the first epoch). The example calls .cuda(), so it requires a CUDA-capable GPU:

import torch
from torch import nn
from torchvision import datasets, transforms
from tqdm import tqdm

from convkan import ConvKAN, LayerNorm2D

# Define the model
model = nn.Sequential(
    ConvKAN(1, 32, padding=1, kernel_size=3, stride=1),
    LayerNorm2D(32),
    ConvKAN(32, 32, padding=1, kernel_size=3, stride=2),
    LayerNorm2D(32),
    ConvKAN(32, 10, padding=1, kernel_size=3, stride=2),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
).cuda()

# Define transformations and download the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

# Train the model
model.train()

pbar = tqdm(train_loader)
for x, y in pbar:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    y_hat = model(x)
    loss = criterion(y_hat, y)
    loss.backward()
    optimizer.step()
    pbar.set_description(f'Loss: {loss.item():.2e}')

# Evaluate on the test set
model.eval()
correct = 0
total = 0

with torch.no_grad():
    pbar = tqdm(test_loader)
    for x, y in pbar:
        x, y = x.cuda(), y.cuda()
        y_hat = model(x)
        _, predicted = torch.max(y_hat, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()
        pbar.set_description(f'Accuracy: {100 * correct / total:.2f}%')
