
CVKAN: Complex-Valued Kolmogorov-Arnold Networks


Authors: Matthias Wolff, Florian Eilers, Xiaoyi Jiang
University of Münster, Department of Computer Science

Link to Paper: https://arxiv.org/abs/2502.02417


Abstract

In this work we propose $\mathbb{C}$KAN, a complex-valued KAN, to combine the intrinsic interpretability of KANs with the advantages of Complex-Valued Neural Networks (CVNNs). We show how to transfer a KAN and the necessary associated mechanisms into the complex domain. To confirm that $\mathbb{C}$KAN meets expectations, we conduct experiments on symbolic complex-valued function fitting and physically meaningful formulae, as well as on a more realistic dataset from knot theory. Our proposed $\mathbb{C}$KAN is more stable and performs on par with or better than real-valued KANs while requiring fewer parameters and a shallower network architecture, making it more explainable.

<CVKAN Plot>




How to use

See demo.py for a complete example; the individual steps are shown below.

Install

pip install cvkan

Imports

from cvkan import CVKANWrapper, train_kans, KANPlot
from cvkan.models.CVKAN import Norms
from cvkan.utils import create_complex_dataset, CSVDataset
from cvkan.utils.loss_functions import MSE, MAE

Create Dataset

# Generate dataset for f(z)=(z1^2 + z2^2)^2
f_squaresquare = lambda x: (x[:,0]**2 + x[:,1]**2)**2
# create dataset (this is a dictionary with keys 'train_input', 'train_label', 'test_input' and 'test_label', each
# containing a Tensor as value)
dataset = create_complex_dataset(f=f_squaresquare, n_var=2, ranges=[-1,1], train_num=5000, test_num=1000)
# convert dataset to CSVDataset object for easier handling later
dataset = CSVDataset(dataset, input_vars=["z1", "z2"], output_vars=["(z1^2 + z2^2)^2"], categorical_vars=[])
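To make the structure of this dictionary concrete, here is a minimal NumPy sketch of how such a dataset could be generated. The function `make_complex_dataset` is hypothetical and not part of cvkan; it assumes that the real and imaginary parts of every input are sampled independently and uniformly from `ranges`, and it returns NumPy arrays where the library returns PyTorch tensors.

```python
import numpy as np

def make_complex_dataset(f, n_var, ranges, train_num, test_num, seed=0):
    """Hypothetical stand-in for cvkan.utils.create_complex_dataset.

    Assumption: real and imaginary parts are drawn independently and
    uniformly from `ranges`; the library's actual sampling may differ.
    """
    rng = np.random.default_rng(seed)
    low, high = ranges

    def sample(n):
        real = rng.uniform(low, high, size=(n, n_var))
        imag = rng.uniform(low, high, size=(n, n_var))
        z = real + 1j * imag          # complex inputs, shape (n, n_var)
        y = f(z).reshape(n, 1)        # one complex-valued output column
        return z, y

    train_input, train_label = sample(train_num)
    test_input, test_label = sample(test_num)
    # Same keys as described for create_complex_dataset above
    return {
        "train_input": train_input,
        "train_label": train_label,
        "test_input": test_input,
        "test_label": test_label,
    }

f_squaresquare = lambda x: (x[:, 0]**2 + x[:, 1]**2)**2
data = make_complex_dataset(f_squaresquare, n_var=2, ranges=[-1, 1],
                            train_num=5000, test_num=1000)
print(data["train_input"].shape, data["train_label"].shape)  # (5000, 2) (5000, 1)
```

As a quick check of the target function: f(1+i, 1) = ((1+i)^2 + 1)^2 = (1+2i)^2 = -3+4i.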

$\mathbb{C}$KAN

# create the CVKAN model. CVKANWrapper behaves like CVKAN but adds
# features needed for plotting later on
cvkan_model = CVKANWrapper(layers_hidden=[2,1,1], num_grids=8, use_norm=Norms.BatchNorm, grid_mins=-2, grid_maxs=2, csilu_type="complex_weight")
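The grid parameters can be read as describing basis functions on the complex plane. The NumPy sketch below is one plausible interpretation, not the library's implementation: it assumes that `num_grids=8` places an 8 × 8 lattice of centres over [grid_mins, grid_maxs] on both the real and the imaginary axis, and that each edge computes a complex-weighted sum of Gaussian radial basis functions of |z - centre|. The names `complex_rbf_grid` and `edge_function`, and the choice of an RBF width equal to the grid spacing, are hypothetical; the `csilu_type` option and the normalization layer are not modeled.

```python
import numpy as np

def complex_rbf_grid(num_grids, grid_min, grid_max):
    """Assumed layout: num_grids x num_grids centres covering
    [grid_min, grid_max] on both the real and the imaginary axis."""
    ticks = np.linspace(grid_min, grid_max, num_grids)
    re, im = np.meshgrid(ticks, ticks, indexing="ij")
    return (re + 1j * im).ravel()                      # shape (num_grids**2,)

def edge_function(z, centres, weights, width):
    """Hypothetical single-edge activation: complex-weighted sum of Gaussian RBFs.

    z:       complex array, shape (batch,)
    centres: complex array, shape (num_centres,)
    weights: complex array, shape (num_centres,)
    """
    dist2 = np.abs(z[:, None] - centres[None, :]) ** 2  # (batch, num_centres)
    basis = np.exp(-dist2 / width**2)                    # real-valued RBF responses
    return basis @ weights                               # complex output, shape (batch,)

num_grids, grid_min, grid_max = 8, -2.0, 2.0
centres = complex_rbf_grid(num_grids, grid_min, grid_max)
width = (grid_max - grid_min) / (num_grids - 1)          # assumption: width = grid spacing
rng = np.random.default_rng(0)
weights = rng.normal(size=centres.size) + 1j * rng.normal(size=centres.size)
out = edge_function(np.array([0.5 - 0.3j, 1.2 + 1.0j]), centres, weights, width)
print(centres.size, out.shape)  # 64 (2,)
```

Under these assumptions, each edge in the sketch carries 8² = 64 complex weights, i.e. 128 real parameters.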



# train cvkan_model on dataset
results = train_kans(cvkan_model,  # model
                     dataset,  # dataset
                     loss_fn_backprop=MSE(),  # loss function used for backpropagation
                     loss_fns={"mse": MSE(), "mae": MAE()},  # loss functions to evaluate the model with
                     epochs=500,  # number of training epochs
                     batch_size=1000,  # batch size for training
                     kan_explainer=None,  # an explainer could be given to make each edge's transparency reflect its relevance
                     add_softmax_lastlayer=False,  # no softmax after the last layer (this is regression)
                     last_layer_output_real=False  # keep the last layer's output complex-valued (regression)
                     )
print("results of training: \n", results)
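Because the model's outputs are complex-valued, each error metric has to map a complex residual to a real number. A minimal sketch, assuming MSE and MAE are computed from the modulus of the residual; the helper names are hypothetical, and the library's `MSE` and `MAE` classes may be defined differently.

```python
import numpy as np

def complex_mse(pred, target):
    """Mean squared error on complex values: mean of |pred - target|**2 (assumed definition)."""
    return float(np.mean(np.abs(pred - target) ** 2))

def complex_mae(pred, target):
    """Mean absolute error on complex values: mean of |pred - target| (assumed definition)."""
    return float(np.mean(np.abs(pred - target)))

pred = np.array([1 + 1j, 2 - 1j])
target = np.array([1 + 0j, 2 + 1j])
print(complex_mse(pred, target), complex_mae(pred, target))  # 2.5 1.5
```

Since |d|² = Re(d)² + Im(d)², this MSE is the sum of the squared errors of the real and imaginary parts, averaged over samples.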

Plotting

# plot the model
kan_plotter = KANPlot(cvkan_model,
                      kan_explainer=None,
                      input_featurenames=dataset.input_varnames,
                      output_names=dataset.output_varnames,
                      complex_valued=True,
                      )
kan_plotter.plot_all()

On rare occasions, the random weight initialization is poor and the model does not train properly. If you do not get a plot similar to the one shown above, or if the resulting test MSE is much worse than 0.05, run the training again.
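The rerun can be automated with a simple loop over random seeds. This is a sketch, not part of cvkan: `train_until_good` and `train_once` are hypothetical names, and `train_once(seed)` is assumed to build a fresh `CVKANWrapper`, train it with `train_kans`, and return the model together with its test MSE (how to read the test MSE from `results` depends on the library and is not shown here).

```python
import random
from typing import Any, Callable, Tuple

def train_until_good(
    train_once: Callable[[int], Tuple[Any, float]],
    threshold: float = 0.05,
    max_attempts: int = 5,
    first_seed: int = 0,
) -> Tuple[Any, float]:
    """Retrain with a new seed until test MSE < threshold; return the best (model, mse) seen."""
    best = None
    for attempt in range(max_attempts):
        model, test_mse = train_once(first_seed + attempt)
        if best is None or test_mse < best[1]:
            best = (model, test_mse)   # remember the best attempt so far
        if test_mse < threshold:
            break                      # good enough, stop retraining
    return best

# Toy stand-in for a real training run: returns a seed-dependent fake "test MSE".
def fake_train_once(seed: int) -> Tuple[str, float]:
    rng = random.Random(seed)
    return f"model-{seed}", rng.uniform(0.0, 0.2)

model, mse = train_until_good(fake_train_once)
print(model, mse)
```

In a real `train_once`, calling `torch.manual_seed(seed)` before constructing the model makes each attempt reproducible. Keeping the best attempt means you still get a usable model if none reaches the threshold.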



Download files


Source Distribution

cvkan-1.0.0.tar.gz (61.6 kB)

Uploaded Source

Built Distribution


cvkan-1.0.0-py3-none-any.whl (69.8 kB)

Uploaded Python 3

File details

Details for the file cvkan-1.0.0.tar.gz.

File metadata

  • Download URL: cvkan-1.0.0.tar.gz
  • Upload date:
  • Size: 61.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.12

File hashes

Hashes for cvkan-1.0.0.tar.gz
Algorithm Hash digest
SHA256 70b68d2d1c86f70b70c2de7203763897b257f92557861bd40bbb4426f6176bbf
MD5 ec7e6f0bbf62ff9f77e37e1ee582ef8e
BLAKE2b-256 601b9a8d159be38517ae7bedc8d172190986fdbd4737a8fad0bcea45f44f8006


File details

Details for the file cvkan-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: cvkan-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 69.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.12

File hashes

Hashes for cvkan-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 1e8a6cf69d288c89abbc69ff889efcc387f793363b398717230044679ccbf740
MD5 e40a3b5cf93bf28bb812a995e12934cc
BLAKE2b-256 de5c572c7d2a0850f9bbcea512a956938bd416763c1a818e6c257502f62b0738
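To check a downloaded file against the SHA256 digests listed above, Python's standard `hashlib` module is sufficient. A minimal sketch; `sha256_of_file`, `matches_published`, and `PUBLISHED_SHA256` are illustrative names, and the digests are copied from the hash tables above.

```python
import hashlib

# SHA256 digests published for cvkan 1.0.0 (copied from the hash tables above)
PUBLISHED_SHA256 = {
    "cvkan-1.0.0.tar.gz": "70b68d2d1c86f70b70c2de7203763897b257f92557861bd40bbb4426f6176bbf",
    "cvkan-1.0.0-py3-none-any.whl": "1e8a6cf69d288c89abbc69ff889efcc387f793363b398717230044679ccbf740",
}

def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_published(path: str, filename: str) -> bool:
    """Compare a local file's digest with the published digest for `filename`."""
    return sha256_of_file(path) == PUBLISHED_SHA256[filename]

# Usage, in the directory containing the downloaded wheel:
# matches_published("cvkan-1.0.0-py3-none-any.whl", "cvkan-1.0.0-py3-none-any.whl")
```

Alternatively, pip can enforce the digest itself in hash-checking mode: put a line such as `cvkan==1.0.0 --hash=sha256:<digest>` in a requirements file and install it with `pip install --require-hashes -r requirements.txt`.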

