TF-KAN: Kolmogorov-Arnold Networks for TensorFlow
A Keras-native, high-performance implementation of Kolmogorov-Arnold Networks (KANs) for TensorFlow 2.19+.
This library provides easy-to-use Keras layers that replace standard linear transformations with learnable B-spline activation functions, allowing for more expressive and interpretable models.
Key Features
- 🧠 Learnable Activations: Goes beyond fixed activation functions like ReLU or SiLU by learning a complex, data-driven activation on each edge of the network.
- 🧩 Seamless Keras Integration: Use `DenseKAN` and `Conv*DKAN` layers as direct, drop-in replacements for standard Keras layers.
- ⚡ High Performance: Core mathematical operations are compiled into static graphs with `@tf.function` for maximum speed.
- 🔄 Adaptive Grids: Dynamically update spline resolutions based on data, allowing the model to allocate its parameters more effectively.
- 💾 Modern Serialization: Save and load models containing KAN layers with `model.save()` and `tf.keras.models.load_model()`; no `custom_objects` needed.
Installation
```shell
pip install tf-kan
```
Core Concepts
In a traditional neural network, a connection is a single weight (w). In a KAN, each connection is a learnable 1D function (a B-spline), like a smart dimmer switch that can apply a complex curve to the input signal.
You control these functions with two hyperparameters:
- `grid_size`: The resolution of the function. A larger size allows for more complex, "wiggly" functions.
- `spline_order`: The smoothness of the function. An order of 3 (cubic) is recommended for smooth curves.
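Conceptually, each edge evaluates a set of B-spline basis functions over a knot grid and combines them with learnable coefficients. The following numpy sketch illustrates that machinery; the function name `bspline_bases`, the knot layout, and the random coefficients are illustrative assumptions, not the library's internals:

```python
import numpy as np

def bspline_bases(x, knots, order):
    """Cox-de Boor recursion: evaluate all B-spline basis
    functions of the given order on a 1D array x."""
    # Degree-0 bases: indicator of each knot interval.
    bases = [((x >= knots[i]) & (x < knots[i + 1])).astype(float)
             for i in range(len(knots) - 1)]
    # Raise the degree one step at a time.
    for d in range(1, order + 1):
        bases = [
            (x - knots[i]) / (knots[i + d] - knots[i]) * bases[i]
            + (knots[i + d + 1] - x) / (knots[i + d + 1] - knots[i + 1]) * bases[i + 1]
            for i in range(len(bases) - 1)
        ]
    return np.stack(bases, axis=-1)

# grid_size intervals over [-1, 1], extended by `order` extra knots on
# each side so the bases cover the whole domain.
grid_size, order = 8, 3
h = 2.0 / grid_size
knots = np.arange(-1 - order * h, 1 + (order + 1) * h, h)

x = np.linspace(-1, 0.99, 50)
B = bspline_bases(x, knots, order)            # shape (50, grid_size + order)
phi = B @ np.random.randn(grid_size + order)  # one "edge" function: a learnable mix of bases
```

Increasing `grid_size` adds more basis functions (finer resolution), while `spline_order` controls how smooth each basis bump is. On the interior of the grid the bases sum to 1, which keeps the learned function well-scaled.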
Examples
Here are several examples demonstrating how to use `tfkan` for different tasks.
1. Basic Regression
This example builds a simple model to learn a 1D function, showcasing the `DenseKAN` layer.
```python
import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

# 1. Generate some synthetic data for y = sin(pi * x)
x_train = np.linspace(-1, 1, 100)[:, np.newaxis]
y_train = np.sin(np.pi * x_train)

# 2. Build the KAN model
# A small model is enough to learn this simple function
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    DenseKAN(units=16, grid_size=8, spline_order=3, name='kan_layer_1'),
    DenseKAN(units=1, name='kan_output')
])

# 3. Compile and train
model.compile(optimizer='adam', loss='mean_squared_error')
print("--- Training a simple regressor ---")
model.fit(x_train, y_train, epochs=50, verbose=0)

# 4. Test the model
print("--- Prediction ---")
test_input = tf.constant([[0.5]])  # sin(pi * 0.5) = 1.0
prediction = model.predict(test_input)
print(f"Model prediction for input 0.5: {prediction[0][0]:.4f}")

model.summary()
```
2. Image Classification (Hybrid CNN)
Mix standard Keras layers with `Conv2DKAN` and `DenseKAN` to build a powerful hybrid classifier.
```python
import tensorflow as tf
from tfkan.layers import Conv2DKAN, DenseKAN

# 1. Load a dataset (using dummy data here)
(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0

# 2. Build the hybrid model
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),
    # Standard Conv block
    tf.keras.layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    # KAN Conv block with specific KAN arguments
    Conv2DKAN(
        filters=64,
        kernel_size=3,
        padding='same',
        name='kan_conv',
        kan_kwargs={'grid_size': 5, 'spline_order': 3}
    ),
    tf.keras.layers.GlobalAveragePooling2D(),
    # KAN Dense layers for final classification
    DenseKAN(units=128, grid_size=8, name='kan_dense'),
    tf.keras.layers.Dense(units=10, name='output_logits')  # Standard output layer
])

# 3. Compile and train
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print("\n--- Training a hybrid CNN for image classification ---")
# model.fit(x_train, y_train, epochs=1, batch_size=64)  # Uncomment to train
model.summary()
```
3. Advanced Usage: Adaptive Grid Updates
KANs can dynamically update their internal grids to better fit the data distribution. This is useful for refining a pre-trained model.
```python
import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

# 1. Create a model and a sample data batch
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32,)),
    DenseKAN(16, grid_size=5, name='my_kan_layer')
])
sample_data = np.random.randn(100, 32).astype('float32')

# 2. Get the KAN layer from the model
kan_layer = model.get_layer('my_kan_layer')
print(f"Initial grid size: {kan_layer.grid_size}")

# 3. Update the grid based on the sample data
# This re-calculates knot locations to better cover the data's features
print("Updating grid from samples...")
kan_layer.update_grid_from_samples(sample_data)
print("Grid updated successfully.")

# 4. You can also extend the grid to a higher resolution
print("Extending grid to a larger size...")
try:
    kan_layer.extend_grid_from_samples(sample_data, extend_grid_size=10)
    print(f"Grid extended successfully. New grid size: {kan_layer.grid_size}")
except Exception as e:
    print(f"Error during extension: {e}")
```
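Grid updates of this kind typically re-place the knots according to the empirical distribution of the layer's inputs, so spline resolution concentrates where the data actually lives. Here is a minimal numpy sketch of that idea, assuming a simple quantile-based placement; `adaptive_knots` is a hypothetical helper, not the library's actual algorithm:

```python
import numpy as np

def adaptive_knots(samples, grid_size, margin=0.01):
    # Place grid_size + 1 knots at evenly spaced quantiles of the
    # observed inputs: dense where samples cluster, sparse in the tails.
    qs = np.linspace(0.0, 1.0, grid_size + 1)
    knots = np.quantile(samples, qs)
    knots[0] -= margin   # small margin so boundary samples stay in range
    knots[-1] += margin
    return knots

x = np.random.default_rng(0).standard_normal(1000)
knots = adaptive_knots(x, grid_size=5)  # 6 knots covering all samples
```

Compared to a fixed uniform grid on a preset interval, quantile-placed knots adapt automatically when the input distribution shifts during training.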
4. Time Series Forecasting
Use `Conv1DKAN` to find complex temporal patterns in sequential data.
```python
import tensorflow as tf
from tfkan.layers import Conv1DKAN, DenseKAN

# 1. Define model parameters for a time series task
lookback_window = 20  # Number of past time steps to use as input
num_features = 5      # Number of features at each time step
num_classes = 3       # Number of output classes

# 2. Build a model for sequence classification
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(lookback_window, num_features)),
    # 1D KAN convolution to extract temporal features
    Conv1DKAN(
        filters=32,
        kernel_size=3,
        kan_kwargs={'grid_size': 8}
    ),
    tf.keras.layers.GlobalAveragePooling1D(),
    # Dense KAN layers for classification
    DenseKAN(64),
    tf.keras.layers.Dense(num_classes)  # outputs logits
])

# 3. Compile the model
# The final layer has no softmax, so use from_logits=True
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)
)
print("\n--- Time Series Model ---")
model.summary()
```
Contributing
Contributions are welcome! Please feel free to submit a pull request or open an issue.
License
This project is licensed under the MIT License.