A Keras-native implementation of Kolmogorov-Arnold Networks (KANs) for TensorFlow.

TF-KAN: Kolmogorov-Arnold Networks for TensorFlow

A Keras-native, high-performance implementation of Kolmogorov-Arnold Networks (KANs) for TensorFlow 2.19+.

This library provides easy-to-use Keras layers that replace standard linear transformations with learnable B-spline activation functions, allowing for more expressive and interpretable models.


Key Features

  • 🧠 Learnable Activations: Instead of applying a fixed activation function like ReLU or SiLU, each connection learns its own data-driven activation function.
  • 🧩 Seamless Keras Integration: Use DenseKAN and Conv*DKAN layers as direct, drop-in replacements for standard Keras layers.
  • ⚡ High Performance: Core mathematical operations are compiled into static graphs with @tf.function, which reduces Python overhead during training and inference.
  • 🔄 Adaptive Grids: Dynamically update spline resolutions based on data, allowing the model to allocate its parameters more effectively.
  • 💾 Modern Serialization: Save and load models containing KAN layers with model.save() and tf.keras.models.load_model(), with no custom_objects needed.

Installation

pip install tf-kan

Core Concepts

In a traditional neural network, a connection is a single weight (w). In a KAN, each connection is a learnable 1D function (a B-spline), like a smart dimmer switch that can apply a complex curve to the input signal.
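
Written out, and assuming the usual KAN formulation (the symbols below are introduced here for illustration and are not taken from the library's source), a layer with inputs x_1, ..., x_n computes each output as a sum of univariate functions, and each of those functions is a linear combination of B-spline basis functions with trainable coefficients:

```latex
y_j = \sum_{i=1}^{n} \phi_{ij}(x_i), \qquad
\phi_{ij}(x) = \sum_{m} c_{ijm}\, B_m(x)
```

Here B_m are the basis functions fixed by the grid and c_{ijm} are the learned coefficients. Some KAN implementations also add a residual base-activation term to each \phi_{ij}; whether this library does so is not stated in this description.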

You control these functions with two hyperparameters:

  • grid_size: The resolution of the function. A larger size allows for more complex, "wiggly" functions.
  • spline_order: The polynomial degree of the spline pieces, which controls the smoothness of the function. An order of 3 (cubic) is recommended for smooth curves.
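
To make the two hyperparameters concrete, here is a minimal plain-Python sketch of B-spline basis evaluation with the Cox-de Boor recursion. It is not the library's code: the function names are hypothetical, and the knot layout (grid_size uniform intervals on [-1, 1], extended by spline_order knots on each side) is an assumption borrowed from the common KAN formulation. Under that layout each learnable function has grid_size + spline_order coefficients.

```python
def uniform_knots(grid_size, spline_order, lo=-1.0, hi=1.0):
    """Uniform knot vector: grid_size intervals on [lo, hi], extended by
    spline_order extra knots on each side (assumed layout)."""
    h = (hi - lo) / grid_size
    return [lo + i * h for i in range(-spline_order, grid_size + spline_order + 1)]


def bspline_basis(x, knots, spline_order):
    """Evaluate every B-spline basis function at x (Cox-de Boor recursion)."""
    # Degree 0: indicator function of each half-open knot interval.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0
             for i in range(len(knots) - 1)]
    # Raise the degree one step at a time.
    for k in range(1, spline_order + 1):
        basis = [
            (x - knots[i]) / (knots[i + k] - knots[i]) * basis[i]
            + (knots[i + k + 1] - x) / (knots[i + k + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(knots) - k - 1)
        ]
    return basis


knots = uniform_knots(grid_size=8, spline_order=3)
values = bspline_basis(0.3, knots, spline_order=3)
print(len(values))            # grid_size + spline_order = 11
print(round(sum(values), 6))  # 1.0 (the basis functions sum to one inside the grid)
```

A larger grid_size adds more basis functions and therefore more places where the curve can bend. A larger spline_order makes each basis function smoother and wider, so at most spline_order + 1 of them are nonzero at any input.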

Examples

Here are several examples demonstrating how to use tfkan for different tasks.

1. Basic Regression

This example builds a simple model to learn a 1D function, showcasing the DenseKAN layer.

import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

# 1. Generate some synthetic data for y = sin(pi*x)
x_train = np.linspace(-1, 1, 100)[:, np.newaxis]
y_train = np.sin(np.pi * x_train)

# 2. Build the KAN model
# A small model is enough to learn this simple function
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    DenseKAN(units=16, grid_size=8, spline_order=3, name='kan_layer_1'),
    DenseKAN(units=1, name='kan_output')
])

# 3. Compile and train
model.compile(optimizer='adam', loss='mean_squared_error')
print("--- Training a simple regressor ---")
model.fit(x_train, y_train, epochs=50, verbose=0)

# 4. Test the model
print("--- Prediction ---")
test_input = tf.constant([[0.5]]) # sin(pi * 0.5) = 1.0
prediction = model.predict(test_input)
print(f"Model prediction for input 0.5: {prediction[0][0]:.4f}")
model.summary()

2. Image Classification (Hybrid CNN)

Mix standard Keras layers with Conv2DKAN and DenseKAN to build a powerful hybrid classifier.

import tensorflow as tf
from tfkan.layers import Conv2DKAN, DenseKAN

# 1. Load the CIFAR-10 training set and scale pixel values to [0, 1]
(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0

# 2. Build the hybrid model
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),

    # Standard Conv block
    tf.keras.layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),

    # KAN Conv block with specific KAN arguments
    Conv2DKAN(
        filters=64,
        kernel_size=3,
        padding='same',
        name='kan_conv',
        kan_kwargs={'grid_size': 5, 'spline_order': 3}
    ),
    tf.keras.layers.GlobalAveragePooling2D(),

    # KAN Dense layers for final classification
    DenseKAN(units=128, grid_size=8, name='kan_dense'),
    tf.keras.layers.Dense(units=10, name='output_logits') # Standard output layer
])

# 3. Compile and train
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print("\n--- Training a hybrid CNN for image classification ---")
# model.fit(x_train, y_train, epochs=1, batch_size=64) # Uncomment to train
model.summary()
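
The output layer in the model above returns raw logits, which is why the loss is created with from_logits=True. As a rough illustration of what that setting computes for a single sample, here is a plain-Python sketch. The function name is hypothetical and this is not how Keras implements it internally.

```python
import math


def cross_entropy_from_logits(logits, label):
    """Sparse categorical cross-entropy for one sample, from raw logits."""
    # Log-sum-exp with the maximum subtracted for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # Equivalent to -log(softmax(logits)[label]).
    return log_sum_exp - logits[label]


print(cross_entropy_from_logits([2.0, 0.5, -1.0], label=0))
```

If the last layer applied a softmax instead, the loss would need from_logits=False so that the softmax is not applied twice.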

3. Advanced Usage: Adaptive Grid Updates

KANs can dynamically update their internal grids to better fit the data distribution. This is useful for refining a pre-trained model.

import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

# 1. Create a model and a sample data batch
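# An explicit Input layer builds the model up front; recent Keras versions
# discourage passing input_shape to a layer.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32,)),
    DenseKAN(16, grid_size=5, name='my_kan_layer')
])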
sample_data = np.random.randn(100, 32).astype('float32')

# 2. Get the KAN layer from the model
kan_layer = model.get_layer('my_kan_layer')
print(f"Initial grid size: {kan_layer.grid_size}")

# 3. Update the grid based on the sample data
# This re-calculates knot locations to better cover the data's features
print("Updating grid from samples...")
kan_layer.update_grid_from_samples(sample_data)
print("Grid updated successfully.")

# 4. You can also extend the grid to a higher resolution
print("Extending grid to a larger size...")
try:
    kan_layer.extend_grid_from_samples(sample_data, extend_grid_size=10)
    print(f"Grid extended successfully. New grid size: {kan_layer.grid_size}")
except Exception as e:
    print(f"Error during extension: {e}")
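
The idea behind a sample-based grid update can be illustrated without the library. The sketch below shows one common scheme in KAN implementations: place knots at quantiles of the observed inputs and blend them with a uniform grid. The function name and the grid_eps parameter are hypothetical, and update_grid_from_samples in this library may use a different rule.

```python
import random


def grid_from_samples(samples, grid_size, grid_eps=0.02):
    """Return grid_size + 1 knots for one input feature.

    Quantile-based knots follow the data density; grid_eps is the weight
    given to a uniform grid over the same range.
    """
    xs = sorted(samples)
    n = len(xs)
    lo, hi = xs[0], xs[-1]
    knots = []
    for i in range(grid_size + 1):
        adaptive = xs[(i * (n - 1)) // grid_size]   # i-th quantile of the samples
        uniform = lo + (hi - lo) * i / grid_size    # evenly spaced position
        knots.append(grid_eps * uniform + (1.0 - grid_eps) * adaptive)
    return knots


random.seed(0)
samples = [random.gauss(0.0, 1.0) for _ in range(1000)]
knots = grid_from_samples(samples, grid_size=5)
print(knots)
```

For roughly Gaussian inputs, the interior knots end up closer together than the outer ones, so more of the spline's resolution is spent where most of the data lies.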

4. Time Series Classification

Use Conv1DKAN to extract temporal patterns from sequential data. This example classifies each input window into one of num_classes classes.

import tensorflow as tf
from tfkan.layers import Conv1DKAN, DenseKAN

# 1. Define model parameters for a time series task
lookback_window = 20  # Number of past time steps to use as input
num_features = 5      # Number of features at each time step
num_classes = 3       # Number of output classes

# 2. Build a model for sequence classification
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(lookback_window, num_features)),
    
    # 1D KAN convolution to extract temporal features
    Conv1DKAN(
        filters=32,
        kernel_size=3,
        kan_kwargs={'grid_size': 8}
    ),
    tf.keras.layers.GlobalAveragePooling1D(),
    
    # Dense KAN layers for classification
    DenseKAN(64),
    tf.keras.layers.Dense(num_classes)
])

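# 3. Compile the model
# The final Dense layer outputs raw logits, so the loss needs from_logits=True.
# CategoricalCrossentropy expects one-hot encoded labels.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)
)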
print("\n--- Time Series Model ---")
model.summary()
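
The model above expects input of shape (batch, lookback_window, num_features). As a minimal sketch of how a long recording could be cut into such windows, here is a plain-Python helper. The function name and the synthetic series are made up for illustration.

```python
def make_windows(series, lookback_window):
    """Slice a series of per-step feature vectors into overlapping windows.

    series: list of time steps, each a list of num_features values.
    Returns nested lists of shape (num_windows, lookback_window, num_features).
    """
    return [series[start:start + lookback_window]
            for start in range(len(series) - lookback_window + 1)]


# Synthetic series: 100 time steps with 5 features each.
series = [[float(t + f) for f in range(5)] for t in range(100)]
windows = make_windows(series, lookback_window=20)
print(len(windows), len(windows[0]), len(windows[0][0]))  # 81 20 5
```

Before training, the nested lists would typically be converted to a float32 array, for example with np.asarray(windows, dtype='float32'), and paired with one label per window in the format the chosen loss expects.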

Contributing

Contributions are welcome! Please feel free to submit a pull request or open an issue.

License

This project is licensed under the MIT License.
