KAN: Kolmogorov–Arnold Networks on Apple silicon with MLX.

Project description

KAN: Kolmogorov–Arnold Networks in MLX for Apple silicon

This repository contains an example implementation of a Kolmogorov–Arnold Network (KAN) trained on the MNIST dataset with the MLX framework. The example demonstrates how to configure and train the model through a set of command-line arguments.

Based on the paper "KAN: Kolmogorov–Arnold Networks" (Liu et al., 2024, arXiv:2404.19756).

Table of Contents

  • Installation
  • Usage
  • Examples
  • Model Architecture
  • Contributing
  • License

Installation

Usage with PyPI

Install the package:

pip install mlx-kan

Example usage in Python:

from kan_mlx.kan import KAN

# Initialize KAN with a layer specification: input, hidden, and output sizes
kan_model = KAN(layers_hidden=[28 * 28, 64, 10])

Run the quick training script from the command line:

python -m mlx-kan.quick_scripts.quick_train --help
python -m mlx-kan.quick_scripts.quick_train --dataset fashion_mnist --num-layers 3 --hidden-dim 128 --num-epochs 20 --batch-size 128 --learning-rate 0.0005 --seed 42
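
For a quick sanity check, here is a minimal forward-pass sketch (illustrative; it assumes the KAN class described below, with random input standing in for MNIST data):

import mlx.core as mx
from kan_mlx.kan import KAN

# A batch of 4 flattened 28x28 images (random values, just to exercise the model)
x = mx.random.normal((4, 28 * 28))
model = KAN(layers_hidden=[28 * 28, 64, 10])
logits = model(x)
print(logits.shape)  # (4, 10)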

Clone this Repo

To run this example, you need to have Python and the necessary libraries installed. Follow these steps to set up your environment:

  1. Clone the repository:
git clone https://github.com/Goekdeniz-Guelmez/mlx-kan.git
cd mlx-kan
  2. Install the required packages:
pip install -r requirements.txt

Usage

You can run the script main.py to train the KAN model on the MNIST dataset. The script supports various command-line arguments for configuration.

Arguments

  • --cpu: Use the CPU backend instead of the default Metal (GPU) backend.
  • --use-kan-convolution: Use the convolutional KAN architecture. Currently raises an error because it is not implemented yet.
  • --dataset: The dataset to use (mnist or fashion_mnist). Default is mnist.
  • --num-layers: Number of layers in the model. Default is 2.
  • --in-features: Number of input features. Default is 28.
  • --out-features: Number of output features. Default is 28.
  • --num-classes: Number of output classes. Default is 10.
  • --hidden-dim: Number of hidden units in each layer. Default is 64.
  • --num-epochs: Number of epochs to train. Default is 10.
  • --batch-size: Batch size for training. Default is 64.
  • --learning-rate: Learning rate for the optimizer. Default is 1e-3.
  • --weight-decay: Weight decay for the optimizer. Default is 1e-4.
  • --eval-report-count: How often (in epochs) validation and test accuracy are reported. Default is 10.
  • --save-path: Path, including the file name, where the trained KAN model will be saved. Default is traned_kan_model.safetensors.
  • --train-batched: Use batched training instead of training on the full epoch at once. Default is False.
  • --seed: Random seed for reproducibility. Default is 0.
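
For orientation, here is a minimal sketch of how these flags could map onto an argparse parser (illustrative only; the actual parser lives in quick_scripts.quick_train and may differ):

import argparse

parser = argparse.ArgumentParser(description="Train a KAN on MNIST or FashionMNIST")
parser.add_argument("--cpu", action="store_true", help="Use the CPU backend instead of Metal")
parser.add_argument("--dataset", choices=["mnist", "fashion_mnist"], default="mnist")
parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument("--hidden-dim", type=int, default=64)
parser.add_argument("--num-epochs", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--learning-rate", type=float, default=1e-3)
parser.add_argument("--weight-decay", type=float, default=1e-4)
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()  # hyphens become underscores: args.num_layers, args.hidden_dim, ...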

Examples

List all arguments with their descriptions:

python -m quick_scripts.quick_train --help

Basic Usage

Train the KAN model on the MNIST dataset with default settings:

python -m quick_scripts.quick_train --dataset mnist

Custom Configuration

Train the KAN model with a custom configuration:

python -m quick_scripts.quick_train --dataset fashion_mnist --num-layers 3 --hidden-dim 128 --num-epochs 20 --batch-size 128 --learning-rate 0.0005 --seed 42

Using CPU

Train the KAN model using the CPU backend:

python -m quick_scripts.quick_train --cpu --dataset mnist

Model Architecture

The KAN (Kolmogorov–Arnold Network) class defines the model architecture. The network is a stack of KANLinear layers, each configured by the provided parameters. The number of layers and the hidden dimension can be set via command-line arguments.

Example Model Initialization

layers_hidden = [28 * 28] + [hidden_dim] * (num_layers - 1) + [num_classes]
model = KAN(layers_hidden)
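
With the default settings (hidden_dim=64, num_layers=2, num_classes=10), this expands to layers_hidden = [784, 64, 10]: a flattened 28x28 image as a 784-dimensional input, one hidden KANLinear layer of width 64, and a 10-way classification output.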

KAN Class

The KAN class initializes a sequence of KANLinear layers based on the provided hidden-layer configuration. Each layer combines a base linear transformation with learnable spline-based activation functions, following the Kolmogorov–Arnold formulation.
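
Conceptually, each KANLinear layer adds a learnable spline branch to a SiLU-activated linear branch. The following is an illustrative sketch only, not the package's actual implementation; spline_basis stands in for the layer's B-spline expansion:

import mlx.core as mx
import mlx.nn as nn

def kan_linear_concept(x, w_base, w_spline, spline_basis):
    # Base branch: SiLU activation followed by a linear map, as in an MLP.
    base = mx.matmul(nn.silu(x), w_base.T)
    # Spline branch: learnable weights over the B-spline basis values of x.
    spline = mx.matmul(spline_basis(x), w_spline.T)
    return base + spline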

import mlx.nn as nn

class KAN(nn.Module):
    def __init__(self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=nn.SiLU, grid_eps=0.02, grid_range=[-1, 1]):
        super().__init__()
        # One KANLinear layer per consecutive (in, out) pair in layers_hidden.
        # KANLinear is defined alongside this class in kan_mlx.kan.
        self.layers = []
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features, out_features, grid_size, spline_order, scale_noise, scale_base, scale_spline, base_activation, grid_eps, grid_range
                )
            )

    def __call__(self, x, update_grid=False):
        for layer in self.layers:
            if update_grid:
                # Refit each layer's spline grid to the current activations.
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        # Sum per-layer terms with the builtin sum: mx.add is binary, so
        # unpacking a generator into it fails for more than two layers.
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
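
A minimal training-step sketch with MLX (illustrative; it assumes batches of flattened MNIST images with integer labels and uses the standard MLX training utilities):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from kan_mlx.kan import KAN

model = KAN(layers_hidden=[28 * 28, 64, 10])
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=1e-4)

def loss_fn(model, X, y):
    return nn.losses.cross_entropy(model(X), y, reduction="mean")

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

def train_step(X, y):
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)                 # apply the AdamW update
    mx.eval(model.parameters(), optimizer.state)   # force lazy computation
    return loss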

KanConvolutional Class

The KanConvolutional class defines the convolutional model architecture. The network consists of multiple KANConv layers, each configured by the provided parameters. It backs the --use-kan-convolution option, which is not yet implemented.

class KanConvolutional(nn.Module):
    def __init__(self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=nn.SiLU, grid_eps=0.02, grid_range=[-1, 1]):
        super().__init__()
        # One KANConv layer per consecutive (in, out) channel pair.
        self.layers = []
        for in_channels, out_channels in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANConv(
                    in_channels, out_channels, grid_size, spline_order, scale_noise, scale_base, scale_spline, base_activation, grid_eps, grid_range
                )
            )

    def __call__(self, x, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        # Same per-layer summation as in KAN above.
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

Contributing

Contributions are welcome! If you have any suggestions or improvements, feel free to open an issue or submit a pull request.

  1. Fork the repository.
  2. Create a new branch (git checkout -b feature-branch).
  3. Make your changes.
  4. Commit your changes (git commit -m 'Add new feature').
  5. Push to the branch (git push origin feature-branch).
  6. Create a new Pull Request.

License

This project is licensed under the Apache 2.0 License. See the LICENSE file for details.

Made with love by Gökdeniz Gülmez.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mlx_kan-0.1.71.tar.gz (13.8 kB)

Built Distribution

mlx_kan-0.1.71-py3-none-any.whl (14.8 kB)

File details

Details for the file mlx_kan-0.1.71.tar.gz.

File metadata

  • Download URL: mlx_kan-0.1.71.tar.gz
  • Size: 13.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for mlx_kan-0.1.71.tar.gz:

  • SHA256: 1b2ea795786876fe3dc8986ac6fd1439b27fbadab9d6bc2c0dc29f9c44703c2e
  • MD5: 7159c0f6ffbe5476e47d0e2107552152
  • BLAKE2b-256: 1308ea168d99b4531660bb2316028be98b7579f3382ba462ad24526d71dc9952

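To verify a downloaded file against the SHA256 digest above, here is a minimal sketch using only the Python standard library (it assumes the archive sits in the current directory):

import hashlib

expected = "1b2ea795786876fe3dc8986ac6fd1439b27fbadab9d6bc2c0dc29f9c44703c2e"
with open("mlx_kan-0.1.71.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print("OK" if digest == expected else "MISMATCH")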

File details

Details for the file mlx_kan-0.1.71-py3-none-any.whl.

File metadata

  • Download URL: mlx_kan-0.1.71-py3-none-any.whl
  • Size: 14.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for mlx_kan-0.1.71-py3-none-any.whl:

  • SHA256: ec0ba0c4522da124842667770d3f38b359c37eec0cf7c66f4305b9a4c59a8c71
  • MD5: ca489cc71b75768e0872fd3c295c1539
  • BLAKE2b-256: 5507d4c7e413ebf0f3adc6ee5856aca41f7ee3d66c96af3b2005598c70abde6d
