KAN: Kolmogorov–Arnold Networks on Apple silicon with MLX.
KAN: Kolmogorov–Arnold Networks in MLX for Apple silicon
This repository contains an example implementation for training a Kolmogorov–Arnold Network (KAN) on the MNIST dataset with the MLX framework. The model is configured and trained through command-line arguments.
Based on the paper
Installation
Usage with PyPI
Install the package:
pip install mlx-kan
Example usage in Python:
from kan_mlx.kan import KAN
# Initialize and use KAN
kan_model = KAN([28 * 28, 64, 10])  # layers_hidden is required: input, hidden, and output sizes
python -m mlx-kan.quick_scripts.quick_train --help
python -m mlx-kan.quick_scripts.quick_train --dataset fashion_mnist --num-layers 3 --hidden-dim 128 --num-epochs 20 --batch-size 128 --learning-rate 0.0005 --seed 42
Clone this Repo
To run this example, you need to have Python and the necessary libraries installed. Follow these steps to set up your environment:
- Clone the repository:
git clone https://github.com/Goekdeniz-Guelmez/mlx-kan.git
cd mlx-kan
- Install the required packages:
pip install -r requirements.txt
Usage
You can run the script main.py to train the KAN model on the MNIST dataset. The script supports various command-line arguments for configuration.
Arguments
- --cpu: Use the CPU back-end instead of Metal.
- --use-kan-convolution: Use the convolutional KAN architecture. This currently raises an error because it is not implemented yet.
- --dataset: The dataset to use (mnist or fashion_mnist). Default is mnist.
- --num-layers: Number of layers in the model. Default is 2.
- --in-features: Number of input features. Default is 28.
- --out-features: Number of output features. Default is 28.
- --num-classes: Number of output classes. Default is 10.
- --hidden-dim: Number of hidden units in each layer. Default is 64.
- --num-epochs: Number of epochs to train. Default is 10.
- --batch-size: Batch size for training. Default is 64.
- --learning-rate: Learning rate for the optimizer. Default is 1e-3.
- --weight-decay: Weight decay for the optimizer. Default is 1e-4.
- --eval-report-count: Number of epochs to report validation / test accuracy values. Default is 10.
- --save-path: Path, including the model file name, where the trained KAN model will be saved. Default is traned_kan_model.safetensors.
- --train-batched: Use batch training instead of single epoch. Default is False.
- --seed: Random seed for reproducibility. Default is 0.
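The flags and defaults listed above could be declared with argparse roughly as follows. This is a minimal sketch, not the project's actual parser: the helper name build_parser and the help strings are assumptions, and the flag spellings follow the hyphenated form used in the example commands.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    p = argparse.ArgumentParser(description="Train a KAN model (sketch).")
    # Boolean switches default to False and become True when passed.
    p.add_argument("--cpu", action="store_true", help="Use the CPU back-end.")
    p.add_argument("--use-kan-convolution", action="store_true")
    p.add_argument("--train-batched", action="store_true")
    p.add_argument("--dataset", choices=["mnist", "fashion_mnist"], default="mnist")
    # Model shape.
    p.add_argument("--num-layers", type=int, default=2)
    p.add_argument("--in-features", type=int, default=28)
    p.add_argument("--out-features", type=int, default=28)
    p.add_argument("--num-classes", type=int, default=10)
    p.add_argument("--hidden-dim", type=int, default=64)
    # Optimization settings.
    p.add_argument("--num-epochs", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--learning-rate", type=float, default=1e-3)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--eval-report-count", type=int, default=10)
    # The default file name is copied verbatim from the argument list.
    p.add_argument("--save-path", default="traned_kan_model.safetensors")
    p.add_argument("--seed", type=int, default=0)
    return p


# argparse maps "--num-layers" to the attribute "num_layers".
args = build_parser().parse_args(
    ["--dataset", "fashion_mnist", "--num-layers", "3", "--hidden-dim", "128"]
)
print(args.dataset, args.num_layers, args.hidden_dim, args.batch_size)
# fashion_mnist 3 128 64
```

Flags that are not passed keep the documented defaults, which is why batch_size is still 64 in the output.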
Examples
List all arguments with descriptions
python -m quick_scripts.quick_train --help
Basic Usage
Train the KAN model on the MNIST dataset with default settings:
python -m quick_scripts.quick_train --dataset mnist
Custom Configuration
Train the KAN model with a custom configuration:
python -m quick_scripts.quick_train --dataset fashion_mnist --num-layers 3 --hidden-dim 128 --num-epochs 20 --batch-size 128 --learning-rate 0.0005 --seed 42
Using CPU
Train the KAN model using the CPU backend:
python -m quick_scripts.quick_train --cpu --dataset mnist
Model Architecture
The KAN (Kolmogorov–Arnold Networks) class defines the model architecture. The network consists of multiple KANLinear layers, each defined by the provided parameters. The number of layers and the hidden dimension size can be configured via command-line arguments.
Example Model Initialization
layers_hidden = [28 * 28] + [hidden_dim] * (num_layers - 1) + [num_classes]
model = KAN(layers_hidden)
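To make the list arithmetic concrete, here is a small worked example in plain Python. The helper name build_layer_sizes is hypothetical and only wraps the expression above; consecutive pairs of the resulting list are the (in_features, out_features) of each layer.

```python
def build_layer_sizes(hidden_dim: int, num_layers: int, num_classes: int,
                      input_dim: int = 28 * 28) -> list[int]:
    """Hypothetical helper: same expression as layers_hidden above."""
    return [input_dim] + [hidden_dim] * (num_layers - 1) + [num_classes]


# Default settings: 2 layers, 64 hidden units, 10 classes.
sizes = build_layer_sizes(hidden_dim=64, num_layers=2, num_classes=10)
print(sizes)  # [784, 64, 10]

# Each consecutive pair becomes one layer's (in_features, out_features).
pairs = list(zip(sizes, sizes[1:]))
print(pairs)  # [(784, 64), (64, 10)]

# The custom configuration from the examples: 3 layers, 128 hidden units.
print(build_layer_sizes(hidden_dim=128, num_layers=3, num_classes=10))
# [784, 128, 128, 10]
```

A list of n sizes therefore yields n - 1 layers, so num_layers counts the layers, not the entries in layers_hidden.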
KAN Class
The KAN class initializes a sequence of KANLinear layers based on the provided hidden layers configuration. Each layer applies learnable spline-based activation functions (configured by grid_size and spline_order) alongside a base activation (base_activation).
import mlx.core as mx
import mlx.nn as nn

# KANLinear is assumed to be defined in the same module.


class KAN(nn.Module):
    def __init__(self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=nn.SiLU, grid_eps=0.02, grid_range=[-1, 1]):
        super().__init__()
        # One KANLinear layer per consecutive pair of sizes in layers_hidden.
        self.layers = []
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features, out_features, grid_size, spline_order, scale_noise, scale_base, scale_spline, base_activation, grid_eps, grid_range
                )
            )

    def __call__(self, x, update_grid=False):
        for layer in self.layers:
            # Optionally refit each layer's spline grid to its current input.
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        # sum() works for any number of layers; mx.add accepts exactly two operands.
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
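The grid_size, spline_order, and grid_range arguments suggest that each layer's learnable activations are B-splines on a uniform grid. The following pure-Python sketch evaluates such a basis with the Cox-de Boor recursion. It assumes a uniform grid extended by spline_order intervals on each side; the function name bspline_basis is hypothetical, and this is an illustration rather than the library's own KANLinear code.

```python
def bspline_basis(x: float, grid_size: int = 5, spline_order: int = 3,
                  grid_range: tuple[float, float] = (-1.0, 1.0)) -> list[float]:
    """Evaluate all B-spline basis functions at x (illustrative sketch)."""
    lo, hi = grid_range
    h = (hi - lo) / grid_size
    # Uniform knots, extended by spline_order intervals beyond each end.
    knots = [lo + (i - spline_order) * h
             for i in range(grid_size + 2 * spline_order + 1)]
    # Order 0: indicator function of each knot interval.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0
             for i in range(len(knots) - 1)]
    # Cox-de Boor recursion: each step raises the order by one
    # and shortens the list of basis functions by one.
    for k in range(1, spline_order + 1):
        basis = [
            (x - knots[i]) / (knots[i + k] - knots[i]) * basis[i]
            + (knots[i + k + 1] - x) / (knots[i + k + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(basis) - 1)
        ]
    return basis


values = bspline_basis(0.3)
print(len(values))             # 8, i.e. grid_size + spline_order
print(round(sum(values), 9))   # 1.0: the basis sums to one inside grid_range

# A learnable activation is then a weighted sum of the basis values.
coeffs = [0.1 * j for j in range(len(values))]  # made-up coefficients
phi = sum(c * b for c, b in zip(coeffs, values))
```

Under these assumptions each input has grid_size + spline_order spline coefficients per output, and at most spline_order + 1 basis functions are nonzero at any point.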
KanConvolutional Class
The KanConvolutional class defines the convolutional model architecture. The network consists of multiple KANConv layers, each defined by the provided parameters. This class is used for models that require convolutional layers.
# Assumes mlx.core as mx, mlx.nn as nn, and KANConv are available in this module.
class KanConvolutional(nn.Module):
    def __init__(self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=nn.SiLU, grid_eps=0.02, grid_range=[-1, 1]):
        super().__init__()
        # One KANConv layer per consecutive pair of channel counts.
        self.layers = []
        for in_channels, out_channels in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANConv(
                    in_channels, out_channels, grid_size, spline_order, scale_noise, scale_base, scale_spline, base_activation, grid_eps, grid_range
                )
            )

    def __call__(self, x, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        # sum() works for any number of layers; mx.add accepts exactly two operands.
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
Contributing
Contributions are welcome! If you have any suggestions or improvements, feel free to open an issue or submit a pull request.
- Fork the repository.
- Create a new branch (git checkout -b feature-branch).
- Make your changes.
- Commit your changes (git commit -m 'Add new feature').
- Push to the branch (git push origin feature-branch).
- Create a new Pull Request.
License
This project is licensed under the Apache 2.0 License. See the LICENSE file for details.
Made with love by Gökdeniz Gülmez.