A fractal regularization technique for neural networks
Fractal Regularization
A PyTorch-based implementation of fractal regularization for neural networks. This technique leverages box-counting and multi-resolution analysis to provide a novel approach to network regularization, potentially improving model generalization and performance.
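To make the box-counting idea concrete, here is a minimal sketch that estimates a box-counting dimension for a one-dimensional set of values, such as a flattened weight vector. It is not taken from the package's source: the function name `box_counting_dimension` and the default box sizes are assumptions chosen for illustration, and the code uses plain Python so that it runs without PyTorch.

```python
import math


def box_counting_dimension(values, box_sizes=(0.5, 0.25, 0.125, 0.0625)):
    """Estimate the box-counting dimension of a 1-D point set.

    Hypothetical helper for illustration; not the package's implementation.
    """
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return 0.0  # all points coincide, so the set behaves like a single point

    # Rescale to [0, 1] so that box sizes are relative to the data range.
    unit = [(v - lo) / span for v in values]

    xs, ys = [], []
    for eps in box_sizes:
        n_boxes = int(round(1 / eps))
        # Index of the box each point falls into; clamp the right endpoint.
        occupied = {min(int(u / eps), n_boxes - 1) for u in unit}
        xs.append(math.log(1 / eps))
        ys.append(math.log(len(occupied)))

    # Least-squares slope of log N(eps) against log(1 / eps).
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    num = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    den = sum((x - mean_x) ** 2 for x in xs)
    return num / den
```

Values that fill their range densely give an estimate close to 1, while a handful of isolated values give an estimate close to 0.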
Installation
Install the package using pip:
pip install fractal-regularization
Features
- Box-counting based fractal dimension analysis
- Multi-resolution regularization
- Learnable L1 regularization scaling
- Dynamic adaptation based on training progress
- Compatible with PyTorch neural networks
- Easy integration with existing training loops
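One plausible reading of multi-resolution regularization is a penalty evaluated on progressively coarser versions of the weights. The sketch below illustrates that reading in plain Python. The helper names `coarsen` and `multi_resolution_penalty`, and the `levels` and `factor` arguments, are hypothetical and may not match how the package defines `resolution_factor`.

```python
def coarsen(values, factor=2):
    """Average consecutive groups of `factor` values (the last group may be shorter)."""
    return [
        sum(values[i:i + factor]) / len(values[i:i + factor])
        for i in range(0, len(values), factor)
    ]


def multi_resolution_penalty(values, levels=3, factor=2):
    """Sum the mean absolute value of `values` at successively coarser resolutions.

    Illustrative assumption only; not the package's implementation.
    """
    penalty = 0.0
    current = list(values)
    for _ in range(levels):
        if not current:
            break
        penalty += sum(abs(v) for v in current) / len(current)
        if len(current) == 1:
            break  # cannot coarsen a single value any further
        current = coarsen(current, factor)
    return penalty
```

With this formulation, rapidly oscillating weights cancel out at coarser levels and contribute little beyond the first level, whereas weights with a consistent sign contribute at every level.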
Usage
Here's a complete example of using fractal regularization with a neural network:
from fractal_reg import ComplexFractalRegularizationLoss
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler
# Define your neural network
class ComplexNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
# Initialize the loss function
fractal_loss = ComplexFractalRegularizationLoss(
    alpha=0.1,            # Scaling factor for fractal loss
    beta=0.01,            # Scaling factor for multi-resolution loss
    lambda_l1_init=0.1,   # Initial value for learnable L1 scaling
    resolution_factor=2   # Controls depth of multi-resolution analysis
)
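# Prepare the data (California Housing: 8 input features, 1 regression target).
# The split ratio and random seed are example values.
data = fetch_california_housing()
X_train, X_val, y_train, y_val = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42
)
scaler = StandardScaler()
train_inputs = torch.tensor(scaler.fit_transform(X_train), dtype=torch.float32)
train_targets = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
# Training setup
model = ComplexNN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Training loop
num_epochs = 250
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(train_inputs)
    # Calculate main loss
    mse_loss = criterion(outputs, train_targets)
    # Add fractal regularization (called with the model, current epoch, and total epochs)
    fractal_regularization = fractal_loss(model, epoch, num_epochs)
    # Combined loss
    total_loss = mse_loss + fractal_regularization
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()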
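The loss is called with the current epoch and the total number of epochs, which is what allows it to adapt to training progress. The exact schedule the package uses is not documented here, so the sketch below shows only one hypothetical choice: a linear ramp. The name `progress_weight` and its `start` and `end` arguments are invented for illustration.

```python
def progress_weight(epoch, total_epochs, start=0.0, end=1.0):
    """Linearly interpolate a regularization weight from `start` to `end`.

    Hypothetical schedule; the package may adapt in a different way.
    """
    if total_epochs <= 1:
        return end
    # Fraction of training completed, clamped to [0, 1].
    progress = min(max(epoch / (total_epochs - 1), 0.0), 1.0)
    return start + (end - start) * progress
```

A weight like this could multiply a regularization term so that it is weak early in training and reaches full strength in the final epoch.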
Parameters
The ComplexFractalRegularizationLoss class accepts the following parameters:
- alpha (float, default=0.1): Scaling factor for the fractal loss component
- beta (float, default=0.01): Scaling factor for the multi-resolution loss component
- lambda_l1_init (float, default=0.1): Initial value for the learnable L1 scaling factor
- resolution_factor (int, default=2): Controls the depth of multi-resolution analysis
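To show how the scaling parameters relate to one another, the sketch below combines three already-computed terms into a single regularization value. The additive form and the function name `combine_regularization` are assumptions; the package may combine its components differently.

```python
def combine_regularization(fractal_term, multires_term, l1_term,
                           alpha=0.1, beta=0.01, lambda_l1=0.1):
    """Weighted sum of three regularization terms.

    Hypothetical combination rule for illustration only.
    """
    return alpha * fractal_term + beta * multires_term + lambda_l1 * l1_term
```

Under this assumption, and with the default values, the fractal and L1 terms are each weighted ten times more heavily than the multi-resolution term.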
Example Results
When applied to a regression task on the California Housing dataset, fractal regularization is reported to improve model generalization compared to standard approaches. The training and validation loss curves can be plotted as follows:
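# Plot training curves.
# train_losses and val_losses are assumed to be dicts that map a method name
# (here 'Fractal') to a list of per-epoch loss values recorded during training.
plt.plot(train_losses['Fractal'], label="Training Loss (Fractal)")
plt.plot(val_losses['Fractal'], label="Validation Loss (Fractal)")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss with Fractal Regularization")
plt.legend()
plt.show()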
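Besides reading the curves by eye, generalization can be summarized as the gap between validation and training loss near the end of training. The helper below is a minimal sketch of that idea. The name `generalization_gap` and the `last_n` window are assumptions, and it expects the per-epoch loss lists used in the plot.

```python
def generalization_gap(train_losses, val_losses, last_n=10):
    """Mean validation loss minus mean training loss over the final `last_n` epochs."""
    n = min(last_n, len(train_losses), len(val_losses))
    if n == 0:
        raise ValueError("need at least one recorded epoch")
    train_tail = train_losses[-n:]
    val_tail = val_losses[-n:]
    return sum(val_tail) / n - sum(train_tail) / n
```

For example, `generalization_gap(train_losses['Fractal'], val_losses['Fractal'])` could be compared with the same quantity for a baseline run; a smaller gap suggests less overfitting.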
Requirements
- Python >= 3.6
- PyTorch >= 1.9.0
- NumPy >= 1.19.0
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
License
This project is licensed under the MIT License - see the LICENSE file for details.
File details
Details for the file fractal_regularization-0.1.1.tar.gz.
File metadata
- Download URL: fractal_regularization-0.1.1.tar.gz
- Upload date:
- Size: 4.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 904bdd551a92eb8988e874583fcc5cbddaad2465bcc36f67bc3f640e5ad25e9b |
| MD5 | 79165d0a0b552acd234b8995b0088800 |
| BLAKE2b-256 | e0b73c567fe23537f2279a5f8241539beea2b06c5f61f61449a353c5ed98a9bd |
File details
Details for the file fractal_regularization-0.1.1-py3-none-any.whl.
File metadata
- Download URL: fractal_regularization-0.1.1-py3-none-any.whl
- Upload date:
- Size: 5.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8279968beb55269fa732d4ce948c2a292eca8284f2e9df2a3c685603d62ae69a |
| MD5 | 7c1ab8f24dc43e44f6e6d3842ef1c248 |
| BLAKE2b-256 | 5171a457eca13eeb26496b3b716c41519cbaf2bed946b4307afd33d3033d7093 |
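The published digests can be used to check a downloaded file before installing it. The following is a minimal sketch using Python's standard `hashlib` module. The helper names `sha256_of_file` and `matches_published_hash` are illustrative, and the file path is whatever location the archive was saved to.

```python
import hashlib


def sha256_of_file(path, chunk_size=65536):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    """Compare a file's SHA256 digest with a published hex digest."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For the source distribution, a check could look like `matches_published_hash("fractal_regularization-0.1.1.tar.gz", "904bdd551a92eb8988e874583fcc5cbddaad2465bcc36f67bc3f640e5ad25e9b")`, assuming the archive is in the current directory.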