
A fractal regularization technique for neural networks

Project description

Fractal Regularization

A PyTorch-based implementation of fractal regularization for neural networks. This technique leverages box-counting and multi-resolution analysis to provide a novel approach to network regularization, potentially improving model generalization and performance.

Installation

Install the package using pip:

pip install fractal-regularization

Features

  • Box-counting based fractal dimension analysis
  • Multi-resolution regularization
  • Learnable L1 regularization scaling
  • Dynamic adaptation based on training progress
  • Compatible with PyTorch neural networks
  • Easy integration with existing training loops
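
The package's internal box-counting routine is not shown on this page, but the idea behind the first feature can be sketched in a few lines. The following is an illustrative standalone implementation (the function name, threshold, and grid scales are choices made here for demonstration, not the package's API): binarize a 2-D weight matrix by magnitude, count occupied boxes at several grid scales, and take the slope of log N(s) against log s as the dimension estimate.

```python
import numpy as np

def box_counting_dimension(weights, threshold=0.01, scales=(2, 4, 8, 16)):
    """Estimate the box-counting (fractal) dimension of a 2-D weight matrix."""
    active = np.abs(np.asarray(weights)) > threshold
    counts = []
    for s in scales:
        h, w = active.shape
        bh, bw = int(np.ceil(h / s)), int(np.ceil(w / s))
        # Pad so the matrix divides evenly into an s-by-s grid of boxes
        padded = np.zeros((bh * s, bw * s), dtype=bool)
        padded[:h, :w] = active
        # Count boxes that contain at least one above-threshold weight
        boxes = padded.reshape(s, bh, s, bw).any(axis=(1, 3))
        counts.append(boxes.sum())
    # Dimension = slope of log N(s) against log s across grid scales
    slope, _ = np.polyfit(np.log(scales), np.log(np.maximum(counts, 1)), 1)
    return slope
```

A fully dense matrix gives an estimate near 2 (plane-filling) and a diagonal matrix gives 1 (curve-like); intermediate values indicate sparser, self-similar weight structure of the kind a fractal regularizer can target.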

Usage

Here's a complete example of how to use the fractal regularization with a neural network:

from fractal_reg import ComplexFractalRegularizationLoss
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

# Define your neural network
class ComplexNN(nn.Module):
    def __init__(self):
        super(ComplexNN, self).__init__()
        self.fc1 = nn.Linear(8, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x

# Initialize the loss function
fractal_loss = ComplexFractalRegularizationLoss(
    alpha=0.1,      # Scaling factor for fractal loss
    beta=0.01,      # Scaling factor for multi-resolution loss
    lambda_l1_init=0.1,  # Initial value for learnable L1 scaling
    resolution_factor=2   # Controls depth of multi-resolution analysis
)

# Prepare the California Housing data
data = fetch_california_housing()
X_train, X_val, y_train, y_val = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42
)
scaler = StandardScaler()
train_inputs = torch.tensor(scaler.fit_transform(X_train), dtype=torch.float32)
train_targets = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
val_inputs = torch.tensor(scaler.transform(X_val), dtype=torch.float32)
val_targets = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1)

# Training setup
model = ComplexNN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(250):
    # Forward pass
    outputs = model(train_inputs)
    
    # Calculate main loss
    mse_loss = criterion(outputs, train_targets)
    
    # Add fractal regularization
    fractal_regularization = fractal_loss(model, epoch, 250)
    
    # Combined loss
    total_loss = mse_loss + fractal_regularization
    
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
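
To produce the curves plotted under Example Results below, the loop needs to record per-epoch losses. A minimal sketch of that bookkeeping follows; the placeholder model and random tensors stand in for ComplexNN and the housing data, and the dict keys are chosen to match the plotting snippet.

```python
import torch
import torch.nn as nn

# Placeholder model and data standing in for ComplexNN and the
# California Housing tensors from the usage example
model = nn.Linear(8, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train_inputs, train_targets = torch.randn(64, 8), torch.randn(64, 1)
val_inputs, val_targets = torch.randn(16, 8), torch.randn(16, 1)

# Per-epoch loss histories, keyed the way the plotting snippet expects
train_losses = {'Fractal': []}
val_losses = {'Fractal': []}

for epoch in range(5):
    outputs = model(train_inputs)
    loss = criterion(outputs, train_targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_losses['Fractal'].append(loss.item())
    # Evaluate on held-out data without tracking gradients
    with torch.no_grad():
        val_loss = criterion(model(val_inputs), val_targets)
    val_losses['Fractal'].append(val_loss.item())
```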

Parameters

The ComplexFractalRegularizationLoss class accepts the following parameters:

  • alpha (float, default=0.1): Scaling factor for the fractal loss component
  • beta (float, default=0.01): Scaling factor for the multi-resolution loss component
  • lambda_l1_init (float, default=0.1): Initial value for the learnable L1 scaling factor
  • resolution_factor (int, default=2): Controls the depth of multi-resolution analysis
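
How these parameters combine is internal to the package, but a rough, assumed-form sketch may help build intuition. The function below is an illustration, not the library's code: the exact multi-resolution penalty is invented here (pool each weight matrix down by resolution_factor, upsample it back, and penalize the residual detail), lambda_l1 is treated as a fixed scalar rather than a learnable one, and the alpha-scaled box-counting term is omitted.

```python
import torch
import torch.nn.functional as F

def sketch_regularizer(model, beta=0.01, lambda_l1=0.1, resolution_factor=2):
    """Illustrative (assumed-form) combination of beta and lambda_l1 terms."""
    multires = torch.tensor(0.0)
    l1 = torch.tensor(0.0)
    for p in model.parameters():
        # lambda_l1 scales a plain L1 penalty over all parameters
        l1 = l1 + p.abs().sum()
        if p.dim() == 2 and min(p.shape) >= resolution_factor:
            fine = p.abs().unsqueeze(0).unsqueeze(0)          # (1, 1, H, W)
            # Coarse view of the weight magnitudes, then upsampled back
            coarse = F.avg_pool2d(fine, kernel_size=resolution_factor)
            smoothed = F.interpolate(coarse, size=fine.shape[-2:],
                                     mode='nearest')
            # Penalize fine-scale detail that the coarse view cannot explain
            multires = multires + (fine - smoothed).pow(2).mean()
    return beta * multires + lambda_l1 * l1
```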

Example Results

When applied to a regression task on the California Housing dataset, fractal regularization showed improved generalization compared to an unregularized baseline:

# Plot training curves; train_losses and val_losses are dicts of
# per-epoch loss histories recorded during the training loop
plt.plot(train_losses['Fractal'], label="Training Loss (Fractal)")
plt.plot(val_losses['Fractal'], label="Validation Loss (Fractal)")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss with Fractal Regularization")
plt.legend()
plt.show()

Requirements

  • Python >= 3.6
  • PyTorch >= 1.9.0
  • NumPy >= 1.19.0

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Project details


Download files

Download the file for your platform.

Source Distribution

fractal_regularization-0.1.1.tar.gz (4.6 kB view details)

Uploaded Source

Built Distribution


fractal_regularization-0.1.1-py3-none-any.whl (5.3 kB view details)

Uploaded Python 3

File details

Details for the file fractal_regularization-0.1.1.tar.gz.

File metadata

  • Download URL: fractal_regularization-0.1.1.tar.gz
  • Upload date:
  • Size: 4.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.6

File hashes

Hashes for fractal_regularization-0.1.1.tar.gz
Algorithm Hash digest
SHA256 904bdd551a92eb8988e874583fcc5cbddaad2465bcc36f67bc3f640e5ad25e9b
MD5 79165d0a0b552acd234b8995b0088800
BLAKE2b-256 e0b73c567fe23537f2279a5f8241539beea2b06c5f61f61449a353c5ed98a9bd


File details

Details for the file fractal_regularization-0.1.1-py3-none-any.whl.

File metadata

File hashes

Hashes for fractal_regularization-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 8279968beb55269fa732d4ce948c2a292eca8284f2e9df2a3c685603d62ae69a
MD5 7c1ab8f24dc43e44f6e6d3842ef1c248
BLAKE2b-256 5171a457eca13eeb26496b3b716c41519cbaf2bed946b4307afd33d3033d7093

