CVKAN: Complex-Valued Kolmogorov-Arnold Networks
Authors: Matthias Wolff, Florian Eilers, Xiaoyi Jiang
University of Münster, Department of Computer Science
Link to Paper: https://arxiv.org/abs/2502.02417
Abstract
In this work we propose $\mathbb{C}$ KAN, a complex-valued KAN, to combine the intrinsic interpretability of KANs with the advantages of Complex-Valued Neural Networks (CVNNs). We show how to transfer a KAN and its necessary associated mechanisms into the complex domain. To confirm that $\mathbb{C}$ KAN meets expectations, we conduct experiments on symbolic complex-valued function fitting and physically meaningful formulae, as well as on a more realistic dataset from knot theory. Our proposed $\mathbb{C}$ KAN is more stable and performs on par with or better than real-valued KANs, while requiring fewer parameters and a shallower network architecture, which makes it more explainable.
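To make "transferring a KAN into the complex domain" concrete, here is a minimal sketch of a single complex-valued KAN edge in plain Python. It assumes a FastKAN-style edge: a complex-weighted sum of Gaussian radial basis functions centred on a 2D grid over the complex plane, plus a residual term using a "split" complex SiLU (real SiLU applied to the real and imaginary parts separately). All function names here are hypothetical, and the exact parameterization in `CVKAN.py` and the SiLU variants in `CompleySilu.py` may differ.

```python
import math


def silu(x: float) -> float:
    """Numerically stable real SiLU: x * sigmoid(x)."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def split_csilu(z: complex) -> complex:
    """Assumed 'split' complex SiLU: real SiLU on the real and imaginary parts separately."""
    return complex(silu(z.real), silu(z.imag))


def complex_grid(num_grids: int, grid_min: float, grid_max: float) -> list:
    """num_grids x num_grids RBF centres spaced evenly over a square in the complex plane."""
    step = (grid_max - grid_min) / (num_grids - 1)
    axis = [grid_min + i * step for i in range(num_grids)]
    return [complex(re, im) for re in axis for im in axis]


def edge(z: complex, weights: list, centres: list,
         residual_weight: complex, width: float) -> complex:
    """Hypothetical complex KAN edge: complex-weighted Gaussian RBFs on a 2D grid plus a CSiLU residual."""
    rbf = sum(w * math.exp(-abs(z - c) ** 2 / width ** 2) for w, c in zip(weights, centres))
    return rbf + residual_weight * split_csilu(z)


# Grid size and range loosely mirror num_grids=8, grid_mins=-2, grid_maxs=2 from the usage example in this README
centres = complex_grid(num_grids=8, grid_min=-2.0, grid_max=2.0)
weights = [0.1 - 0.05j] * len(centres)  # in a real model these would be learned
print(edge(0.5 + 0.25j, weights, centres, residual_weight=1.0 + 0.0j, width=4.0 / 7))
```

Because the grid is two-dimensional, an edge with `num_grids=8` carries 64 complex weights in this sketch; a learnable model would store them as trainable parameters rather than Python lists.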
Table of Contents
- src/cvkan/experiments: Scripts for our experiments and corresponding results
- fit_formulas.py: Experiments for function fitting: simple arbitrary formulae $\left(z^2,\ \sin(z),\ z_1 \cdot z_2,\ (z_1^2 + z_2^2)^2\right)$ as well as physically meaningful formulae (circuit & holography)
- knot_dataset.py: Experiments for knot classification
- results.json: All of our results as a list of dictionaries, stored as JSON
- run_crossval.py: Script to run k-fold cross-validation on a given dataset and model. Also stores the results, with additional metadata, in a JSON file
- src/images: The images used in our paper
- visualizations.py: Script to create some of the images we used in our paper
- src/cvkan/models:
- functions: different helper functions ($\mathbb{C}$ SiLU, BatchNorms)
- CompleySilu.py: Two different variants of complex SiLU
- CV_LayerNorm.py: Different complex-valued BatchNorm approaches and LayerNorm
- wrapper: Wrappers for every KAN that make them work with our KanPlotter and KanExplainer
- CVKANWrapper.py: Wrapper for our CVKAN
- PyKANWrapper.py: Wrapper for pyKAN
- WrapperTemplate.py: Template (Interface) for all specific KAN Wrappers
- CVKAN.py: $\mathbb{C}$ KAN model definition
- FastKAN.py: Modified version of the FastKAN model definition, originally from the GitHub repository ZiyaoLi/fast-kan
- src/cvkan/train/train_loop.py: Main loop for training all kinds of KANs on different datasets using custom loss functions
- src/cvkan/utils: miscellaneous utils
- dataloading: utils for dataloading
- create_complex_dataset.py: Create a complex-valued dataset dictionary from a symbolic formula given as a lambda expression
- crossval_splitter.py: Automatically create datasets for k-fold cross-validation
- csv_dataloader.py: Dataloader and Dataset class for a comma-separated CSV file or dictionary
- latex: Utils to generate LaTeX outputs automatically
- latex_table_creator.py: Automatically generate resulting LaTeX tables from results.json
- plotting: utils for plotting
- cplot.py: Experiments with plotting standard complex-valued functions (e.g. $z^2$)
- cplotting_tools.py: Modified version of the complex-plotting tools, originally from the GitHub repository artmenlope/complex-plotting-tools
- plot_kan.py: Plot KAN models (real-valued as well as complex-valued) with interactive elements
- eval_model.py: Evaluation of models and plotting of confusion matrix
- explain_kan.py: Explain KAN models by calculating edge relevance scores in the same way as Ziming Liu's pyKAN 2.0
- json_editor.py: Manipulate the results.json file
- loss_functions.py: MSE, MAE and cross-entropy loss functions
- misc.py: Miscellaneous short scripts and methods
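`crossval_splitter.py` and `run_crossval.py` handle k-fold cross-validation. The sketch below illustrates only the generic index-splitting step in plain Python: the data is shuffled once, cut into k folds, and each fold serves as the test set exactly once. `kfold_splits` is a hypothetical helper, not part of the cvkan API, and the package's own splitter may shuffle, order or store folds differently.

```python
import random


def kfold_splits(n_samples: int, k: int, seed: int = 0) -> list:
    """Return k (train_indices, test_indices) pairs; every sample lands in exactly one test fold."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must satisfy 2 <= k <= n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # one fixed shuffle so folds are reproducible
    # Spread the remainder over the first folds so fold sizes differ by at most one
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    splits, start = [], 0
    for size in fold_sizes:
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        splits.append((train, test))
        start += size
    return splits


for fold, (train_idx, test_idx) in enumerate(kfold_splits(10, k=3, seed=1)):
    print(fold, len(train_idx), len(test_idx))
```

Training one model per fold and averaging the test metrics then gives the cross-validated score.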
How to use
See demo.py
Install
pip install cvkan
Imports
from cvkan import CVKANWrapper, train_kans, KANPlot
from cvkan.models.CVKAN import Norms
from cvkan.utils import create_complex_dataset, CSVDataset
from cvkan.utils.loss_functions import MSE, MAE
Create Dataset
# Generate dataset for f(z)=(z1^2 + z2^2)^2
f_squaresquare = lambda x: (x[:,0]**2 + x[:,1]**2)**2
# create dataset (this is a dictionary with keys 'train_input', 'train_label', 'test_input' and 'test_label', each
# containing a Tensor as value)
dataset = create_complex_dataset(f=f_squaresquare, n_var=2, ranges=[-1,1], train_num=5000, test_num=1000)
# convert dataset to CSVDataset object for easier handling later
dataset = CSVDataset(dataset, input_vars=["z1", "z2"], output_vars=["(z1^2 + z2^2)^2"], categorical_vars=[])
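The dictionary returned by `create_complex_dataset` has the keys `train_input`, `train_label`, `test_input` and `test_label`. As a rough, dependency-free illustration of how such a dictionary could be built, the sketch below assumes that the real and imaginary parts of every input are drawn uniformly from `ranges`; that sampling scheme is an assumption, and `make_complex_dataset` is a hypothetical stand-in that uses nested Python lists instead of tensors. Unlike `f_squaresquare` above, which indexes whole tensor columns, the formula here receives one sample at a time.

```python
import random


def make_complex_dataset(f, n_var, ranges=(-1.0, 1.0), train_num=5000, test_num=1000, seed=0):
    """Hypothetical stand-in: real and imaginary parts of each input drawn uniformly from `ranges`."""
    rng = random.Random(seed)
    low, high = ranges

    def sample(n):
        inputs = [[complex(rng.uniform(low, high), rng.uniform(low, high)) for _ in range(n_var)]
                  for _ in range(n)]
        labels = [[f(row)] for row in inputs]  # one complex-valued output per sample
        return inputs, labels

    train_input, train_label = sample(train_num)
    test_input, test_label = sample(test_num)
    return {"train_input": train_input, "train_label": train_label,
            "test_input": test_input, "test_label": test_label}


# Per-sample version of f(z) = (z1^2 + z2^2)^2; row is [z1, z2]
f_squaresquare = lambda row: (row[0] ** 2 + row[1] ** 2) ** 2
data = make_complex_dataset(f_squaresquare, n_var=2, train_num=5, test_num=2)
print(sorted(data.keys()))
```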
$\mathbb{C}$ KAN
# create CVKAN model. Note that this is CVKANWrapper, which is basically the same as CVKAN but with additional
# features for plotting later on
cvkan_model = CVKANWrapper(layers_hidden=[2,1,1], num_grids=8, use_norm=Norms.BatchNorm, grid_mins=-2, grid_maxs=2, csilu_type="complex_weight")
# train cvkan_model on dataset
results = train_kans(cvkan_model, # model
dataset, # dataset
loss_fn_backprop=MSE(), # loss function to use for backpropagation
loss_fns={"mse": MSE(), "mae": MAE()}, # loss function dictionary to evaluate the model on
epochs=500, # epochs to train for
batch_size=1000, # batch size for training
kan_explainer=None, # an explainer could be passed here so that each edge's transparency reflects its relevance
add_softmax_lastlayer=False, # we don't need softmax after last layer (as we are doing regression)
last_layer_output_real=False # last layer should also have complex-valued output (regression)
)
print("results of training: \n", results)
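Training and evaluation above use MSE and MAE on complex-valued outputs. A common way to define these for complex numbers is through the magnitude of the complex error, $|\hat{y} - y|$. The plain-Python sketch below uses that definition; it is an assumption for illustration, and the `MSE` and `MAE` classes in `loss_functions.py` may be defined differently (for example, per real and imaginary component).

```python
def complex_mse(pred, target):
    """Mean of |pred - target|^2, the squared magnitude of the complex error."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and of equal length")
    return sum(abs(p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def complex_mae(pred, target):
    """Mean of |pred - target|, the magnitude of the complex error."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and of equal length")
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


pred = [1 + 1j, 0j]
target = [1 + 0j, 3 + 4j]
print(complex_mse(pred, target))  # 13.0, errors have magnitudes 1 and 5
print(complex_mae(pred, target))  # 3.0
```

Since $|e|^2 = \mathrm{Re}(e)^2 + \mathrm{Im}(e)^2$, this complex MSE equals the sum of the squared errors of the real and imaginary parts, so it is twice as large as an MSE averaged over both components separately. Keep that in mind when comparing numbers against real-valued KANs.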
Plotting
# plot the model
kan_plotter = KANPlot(cvkan_model,
kan_explainer=None,
input_featurenames=dataset.input_varnames,
output_names=dataset.output_varnames,
complex_valued=True,
)
kan_plotter.plot_all()
On rare occasions, the random initialization of the weights is suboptimal and the model does not train correctly. If you do not end up with an image similar to the one displayed above, or if the resulting test MSE is much worse than 0.05, please run the script again.
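The "run again" advice can be automated with a small restart loop that tries a few seeds and keeps the best run. In the sketch below, `train_once` is a hypothetical user-supplied function: in practice it would seed PyTorch (e.g. with `torch.manual_seed(seed)`), build a fresh `CVKANWrapper`, call `train_kans`, and return the test MSE. How to read the test MSE out of the `results` object returned by `train_kans` is not documented here, so that step is left open, and the stub below only stands in for real training.

```python
from typing import Callable, Tuple


def train_with_restarts(train_once: Callable[[int], float],
                        threshold: float = 0.05,
                        max_tries: int = 5) -> Tuple[int, float]:
    """Call train_once(seed) for seeds 0, 1, ... until the test MSE is <= threshold.

    Returns (seed, mse) of the best run seen, even if no run reached the threshold.
    """
    best_seed, best_mse = -1, float("inf")
    for seed in range(max_tries):
        mse = train_once(seed)
        if mse < best_mse:
            best_seed, best_mse = seed, mse
        if mse <= threshold:
            break  # good enough, stop early
    return best_seed, best_mse


# Stub standing in for real training runs (hypothetical test MSE per seed)
fake_runs = [0.31, 0.22, 0.01, 0.4]
print(train_with_restarts(lambda seed: fake_runs[seed], threshold=0.05, max_tries=4))  # (2, 0.01)
```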
File details
Details for the file cvkan-1.0.0.tar.gz.
File metadata
- Download URL: cvkan-1.0.0.tar.gz
- Upload date:
- Size: 61.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.12
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 70b68d2d1c86f70b70c2de7203763897b257f92557861bd40bbb4426f6176bbf |
| MD5 | ec7e6f0bbf62ff9f77e37e1ee582ef8e |
| BLAKE2b-256 | 601b9a8d159be38517ae7bedc8d172190986fdbd4737a8fad0bcea45f44f8006 |
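To check a downloaded file against the digests listed for it, you can hash it locally with Python's standard `hashlib` module. The sketch below reads the file in chunks so large files need not fit in memory. The helper name `file_sha256` is illustrative, and the expected value is the SHA256 digest listed for `cvkan-1.0.0.tar.gz`.

```python
import hashlib


def file_sha256(path: str, chunk_size: int = 1 << 16) -> str:
    """Hex SHA-256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


EXPECTED_SDIST_SHA256 = "70b68d2d1c86f70b70c2de7203763897b257f92557861bd40bbb4426f6176bbf"

# After downloading the source distribution into the current directory:
# print(file_sha256("cvkan-1.0.0.tar.gz") == EXPECTED_SDIST_SHA256)
```

`pip hash cvkan-1.0.0.tar.gz` prints the same SHA256 digest from the command line.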
File details
Details for the file cvkan-1.0.0-py3-none-any.whl.
File metadata
- Download URL: cvkan-1.0.0-py3-none-any.whl
- Upload date:
- Size: 69.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.12
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 1e8a6cf69d288c89abbc69ff889efcc387f793363b398717230044679ccbf740 |
| MD5 | e40a3b5cf93bf28bb812a995e12934cc |
| BLAKE2b-256 | de5c572c7d2a0850f9bbcea512a956938bd416763c1a818e6c257502f62b0738 |