
Adaptive Class Weight Adjustment (ACWA) for Imbalanced Deep Learning


Adaptive Class Weight Adjustment (ACWA) - Automated Class Balancing for Deep Learning



🌟 Overview

ACWA is an algorithm that automatically adjusts per-class loss weights during neural network training and is aimed at imbalanced datasets. Unlike traditional approaches, which fix the weights before training starts, ACWA updates them dynamically from per-class performance metrics measured while training runs.

Limitations of Traditional Methods:

  • Static class weighting based on frequency
  • Manual oversampling/undersampling
  • Fixed cost-sensitive learning

ACWA Advantages:

  • 🚀 Real-time performance monitoring
  • ⚖️ Dynamic weight adjustment
  • 🎯 Focus on underperforming classes
  • 🤖 No manual intervention needed

✨ Key Features

  • Adaptive Learning: Adjusts weights based on validation performance
  • Smoothing Mechanism: Prevents drastic weight fluctuations
  • Multi-class Support: Works with any number of classes
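  • Framework Agnostic: The weighting scheme itself is framework-independent; this package targets PyTorch, and TensorFlow/Keras support is listed under Future Improvements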
  • Plug-and-Play: Easy integration into existing pipelines
  • TorchMetrics Integration: Efficient F1-score calculation
  • Dynamic Weight Initialization: Supports inverse class frequency
  • Early Stopping: Prevents overfitting by monitoring validation performance
  • Numerical Stability: Epsilon added to class frequency for robust weight initialization

🧠 Algorithm Design

Core Concept

ACWA operates through a feedback loop:

  1. Monitor class-wise performance
  2. Calculate performance gaps
  3. Adjust weights dynamically

Mathematical Formulation

Performance Error:

error_c = target\_metric - current\_metric_c

Weight Update:

weight_c^{(t+1)} = clip(\beta \cdot weight_c^{(t)} + (1-\beta) \cdot (weight_c^{(t)} + \alpha \cdot error_c), 0.5, 2.0)
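Expanding the smoothing term shows how α and β combine into a single effective step size. This is a short derivation from the update formula above (not taken from the package source), followed by a worked example using values from the hyperparameter table and a hypothetical current metric of 0.60:

\beta \cdot weight_c^{(t)} + (1-\beta) \cdot (weight_c^{(t)} + \alpha \cdot error_c) = weight_c^{(t)} + (1-\beta)\,\alpha \cdot error_c

With \alpha = 0.02, \beta = 0.9, target\_metric = 0.85, current\_metric_c = 0.60 and weight_c^{(t)} = 1.0:

error_c = 0.85 - 0.60 = 0.25

weight_c^{(t+1)} = clip(1.0 + 0.1 \cdot 0.02 \cdot 0.25, 0.5, 2.0) = 1.0005

So a single update moves a weight by at most (1-\beta)\,\alpha \cdot |error_c|; with these settings the weights drift gradually rather than jumping between updates.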

Loss Modification:

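\mathcal{L} = \sum_{c=1}^C weight_c \cdot \mathcal{L}_c

where C is the number of classes and \mathcal{L}_c is the loss contribution of the samples belonging to class c.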
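The three formulas above can be sketched in plain Python as follows. This is a minimal illustration, not the package's implementation: the function names `acwa_update` and `weighted_loss` are hypothetical, and it assumes a single scalar target (the hyperparameter table suggests the target may be class-specific, in which case `target` would become a per-class list).

```python
from typing import List


def acwa_update(weights: List[float], metrics: List[float], target: float,
                alpha: float = 0.02, beta: float = 0.9,
                w_min: float = 0.5, w_max: float = 2.0) -> List[float]:
    """Hypothetical helper: one ACWA step for all classes."""
    new_weights = []
    for w, metric in zip(weights, metrics):
        error = target - metric                     # error_c = target_metric - current_metric_c
        raw = w + alpha * error                     # unsmoothed step toward weaker classes
        smoothed = beta * w + (1 - beta) * raw      # exponential smoothing with factor beta
        new_weights.append(min(max(smoothed, w_min), w_max))  # clip(., 0.5, 2.0)
    return new_weights


def weighted_loss(weights: List[float], class_losses: List[float]) -> float:
    """Hypothetical helper: L = sum_c weight_c * L_c, given one loss value per class."""
    return sum(w * l for w, l in zip(weights, class_losses))
```

For example, `acwa_update([1.0, 1.0], [0.95, 0.60], target=0.85)` returns approximately `[0.9998, 1.0005]`: the class above target is down-weighted slightly and the class below target is up-weighted. In PyTorch, passing the weights as `weight=` to `torch.nn.functional.cross_entropy` with `reduction="sum"` yields exactly this weighted sum when \mathcal{L}_c is the summed loss of class-c samples; with the default `reduction="mean"`, PyTorch instead divides by the sum of the weights of the targets in the batch.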

Hyperparameters

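| Parameter | Argument | Description | Recommended Value |
|---|---|---|---|
| α | alpha | Weight-adjustment rate (step size of the weight update) | 0.01-0.05 |
| β | beta | Smoothing factor | 0.8-0.95 |
| K | update_freq | Update frequency | 50-200 batches |
| Target | target_metric | Performance goal | Class-specific |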
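The feedback loop (monitor per-class performance, compute the gaps, adjust the weights every K batches) can be sketched as below. `ACWAWindow` is a hypothetical helper, not the `acwa` package API. The package's Key Features list TorchMetrics for the F1 calculation; this sketch computes per-class F1 from raw counts (F1_c = 2TP / (2TP + FP + FN)) so it runs without dependencies. Two design choices here are assumptions: counts are reset after each update so every update reflects only the last K batches, and a class that neither appeared nor was predicted in the window keeps its current weight.

```python
from typing import List, Optional, Sequence


class ACWAWindow:
    """Hypothetical helper: collects per-class counts for K batches,
    then applies one ACWA weight update."""

    def __init__(self, num_classes: int, target: float = 0.85, alpha: float = 0.02,
                 beta: float = 0.9, update_freq: int = 100) -> None:
        self.num_classes = num_classes
        self.target = target
        self.alpha = alpha
        self.beta = beta
        self.update_freq = update_freq      # K in the table above
        self.weights = [1.0] * num_classes  # uniform initial weights
        self.batches_seen = 0
        self._reset_counts()

    def _reset_counts(self) -> None:
        self.tp = [0] * self.num_classes
        self.fp = [0] * self.num_classes
        self.fn = [0] * self.num_classes

    def per_class_f1(self) -> List[Optional[float]]:
        """F1 per class; None if the class never appeared and was never predicted."""
        scores: List[Optional[float]] = []
        for c in range(self.num_classes):
            denom = 2 * self.tp[c] + self.fp[c] + self.fn[c]
            scores.append(2 * self.tp[c] / denom if denom else None)
        return scores

    def step(self, preds: Sequence[int], labels: Sequence[int]) -> bool:
        """Record one batch of predicted and true class indices.
        Returns True if this batch triggered a weight update."""
        for p, y in zip(preds, labels):
            if p == y:
                self.tp[y] += 1
            else:
                self.fp[p] += 1
                self.fn[y] += 1
        self.batches_seen += 1
        if self.batches_seen % self.update_freq != 0:
            return False
        for c, f1 in enumerate(self.per_class_f1()):
            if f1 is None:  # no evidence for this class in the window: keep its weight
                continue
            w = self.weights[c]
            error = self.target - f1
            smoothed = self.beta * w + (1 - self.beta) * (w + self.alpha * error)
            self.weights[c] = min(max(smoothed, 0.5), 2.0)
        self._reset_counts()  # start a fresh window of K batches
        return True
```

In a training loop, `preds` would typically be `outputs.argmax(dim=1).tolist()`. Whether the metrics come from training batches or a validation pass is up to the caller; the Key Features describe adjustment based on validation performance.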

🏆 When to Use ACWA

Ideal Scenarios

  • 🏥 Medical diagnosis (rare disease detection)
  • 💳 Fraud detection
  • ⚠️ Rare event prediction
  • 🛡️ Anomaly detection
  • 📊 Highly imbalanced datasets

Comparison with Alternatives

| Method | Pros | Cons |
|---|---|---|
| ACWA | Adaptive, automatic | Slightly more compute |
| Class Weighting | Simple | Static, manual tuning |
| Resampling | Balances data | May lose information |
| Focal Loss | Handles hard samples | Fixed strategy |
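For reference, the standard focal loss for one sample is -w_y (1 - p_y)^γ log(p_y), where p_y is the predicted probability of the true class. The sketch below is an illustration of that formula, not part of the `acwa` package; the `class_weights` argument is an assumption showing where per-class weights such as ACWA's could plug in. Note that core PyTorch does not ship a `FocalLoss` class, so the one used under Advanced Features must be supplied by the user. The fixed γ is what the table means by "fixed strategy": the down-weighting of easy samples never adapts during training.

```python
import math
from typing import Optional, Sequence


def focal_loss(probs: Sequence[float], label: int, gamma: float = 2.0,
               class_weights: Optional[Sequence[float]] = None) -> float:
    """Focal loss for one sample: -w_y * (1 - p_y)**gamma * log(p_y)."""
    p_t = min(max(probs[label], 1e-12), 1.0)   # guard against log(0)
    weight = class_weights[label] if class_weights is not None else 1.0
    # (1 - p_t)**gamma shrinks the loss of well-classified samples (p_t near 1)
    return -weight * (1.0 - p_t) ** gamma * math.log(p_t)
```

With `gamma=0.0` and no class weights this reduces to ordinary cross-entropy on the true-class probability.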

💻 Implementation Guide

Installation

pip install acwa-trainer

Basic Usage

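from acwa import ACWATrainer

# Initialize (assumes `model`, `optimizer`, and `dataloader` are already defined)
trainer = ACWATrainer(
    num_classes=10,
    target_metric=0.85,  # Target F1-score
    alpha=0.02,
    beta=0.9,
    update_freq=100
)

# Training loop
model.train()
for inputs, labels in dataloader:
    optimizer.zero_grad()  # clear gradients from the previous step

    # Forward pass
    outputs = model(inputs)

    # ACWA-weighted loss
    loss = trainer.get_weighted_loss(outputs, labels)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Update per-class metrics used for the next weight adjustment
    trainer.update_metrics(outputs, labels)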

Advanced Features

# Custom metrics
trainer = ACWATrainer(
    metric_fn=custom_f1_function,
    metric_mode='max'  # or 'min'
)

# Combined with Focal Loss (FocalLoss is user-provided; core PyTorch has no FocalLoss class)
trainer = ACWATrainer(
    loss_fn=FocalLoss(gamma=2.0),
    ...
)

# Initialize weights using inverse class frequency
import torch

# minlength ensures every class gets a count, even if it has no samples
class_counts = torch.bincount(torch.tensor(trainset.targets), minlength=10)
class_frequencies = class_counts.float() / (class_counts.sum() + 1e-6)

trainer = ACWATrainer(
    model=model,
    num_classes=10,
    class_frequencies=class_frequencies
)

# Early stopping example
best_f1 = 0
early_stop_counter = 0
patience = 5

for epoch in range(num_epochs):
    # ...training and validation logic; sets val_f1 (validation macro F1) for this epoch...
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), 'best_model.pth')
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    if early_stop_counter >= patience:
        print("Early stopping triggered.")
        break
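The inverse-class-frequency snippet above passes raw frequencies to `ACWATrainer`; how the package converts them into weights is not shown. A plausible conversion, sketched below under that assumption, sets each weight proportional to 1 / (freq_c + ε), mirroring the "epsilon added to class frequency" feature, and rescales so the weights average 1.0, the same scale as the recommended uniform start. The function name `inverse_frequency_weights` is hypothetical.

```python
from collections import Counter
from typing import List, Sequence


def inverse_frequency_weights(targets: Sequence[int], num_classes: int,
                              eps: float = 1e-6) -> List[float]:
    """Inverse-class-frequency weights, rescaled to average 1.0.

    Assumes `targets` is non-empty and holds class indices in [0, num_classes).
    """
    counts = Counter(targets)
    total = len(targets)
    # eps keeps 1 / freq finite for classes with no samples
    freqs = [counts.get(c, 0) / total + eps for c in range(num_classes)]
    inverse = [1.0 / f for f in freqs]
    # Rescale so the mean weight is 1.0, matching a uniform (1.0) initialization
    mean_inverse = sum(inverse) / num_classes
    return [w / mean_inverse for w in inverse]
```

For a 3:1 imbalance, `inverse_frequency_weights([0, 0, 0, 1], num_classes=2)` gives roughly `[0.5, 1.5]`. A class with zero samples receives a very large weight (on the order of 1/ε), so clipping the initial weights into ACWA's [0.5, 2.0] range may be advisable.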

📚 Usage for Different Datasets

Vision Task (CIFAR-100)

  1. Prepare the dataset:
    from demo_cifar100 import main
    main()
    

NLP Task (IMDB with Hugging Face)

  1. Use Hugging Face's datasets library to load the IMDB dataset:
    from demo_huggingface import main
    main()
    

Time-Series Task (UCR)

  1. Load a UCR dataset using tslearn or a similar library.

❓ FAQ

  1. How do I integrate Hugging Face models?

    • Use the transformers library and wrap the model with ACWATrainer.
  2. What if the F1-score does not improve?

    • Try a smaller update_freq (e.g. 20-50 batches) so the weights are updated more often.
    • Check if the dataset is properly preprocessed.

📝 Best Practices

  1. Validation Set: Ensure it has a representative class distribution
  2. Initial Weights: Start with uniform weights (1.0)
  3. Hyperparameter Tuning:
    • Start with α=0.01, β=0.9
    • Adjust based on convergence
  4. Monitoring: Track weight evolution during training
  5. Combination Strategies:
    • Works well with data augmentation
    • Can be combined with focal loss
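# Example weight evolution plot
import matplotlib.pyplot as plt

# weight_history: one list of per-class weights per ACWA update step
plt.plot(weight_history)  # draws one line per class
plt.title('ACWA Weight Adjustment')
plt.xlabel('Update Steps')
plt.ylabel('Class Weight')
plt.show()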

🏅 Benchmark Results

CIFAR-10 (Imbalanced)

| Method | Accuracy | Macro F1 | Training Time |
|---|---|---|---|
| ACWA (Version 3) | 86.3% | 0.781 | 0.7 h |
| ACWA (Final) | 87.5% | 0.799 | 0.65 h |

🤝 Contributing

We welcome contributions!

Future Improvements

  1. Unit Testing:

    • Add test cases for edge scenarios (e.g., empty classes, small batch sizes).
    • Ensure compatibility with various datasets and imbalance ratios.
  2. Distributed Training:

    • Implement support for multi-GPU setups using torch.nn.parallel.DistributedDataParallel.
    • Synchronize metrics across GPUs for consistent weight updates.
  3. Additional Frameworks:

    • Extend support to TensorFlow/Keras for broader adoption.

📜 License

MIT License - Free for academic and commercial use



Download files


Source Distribution

acwa_trainer-1.0.0.tar.gz (10.7 kB)

Uploaded Source

Built Distribution


acwa_trainer-1.0.0-py3-none-any.whl (11.6 kB)

Uploaded Python 3

File details

Details for the file acwa_trainer-1.0.0.tar.gz.

File metadata

  • Download URL: acwa_trainer-1.0.0.tar.gz
  • Upload date:
  • Size: 10.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for acwa_trainer-1.0.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 52e5aa06b5271ea2dd007739c0d0fd93fc5cce01de0924c95000cac6a017712e |
| MD5 | e17c024e0f38282c22aff054baa79de7 |
| BLAKE2b-256 | b25492257c5bdc99a2cc01dcba4361190a512619215ba3740f6f433b20207e6e |


File details

Details for the file acwa_trainer-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: acwa_trainer-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 11.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for acwa_trainer-1.0.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 126b566ae51883fe37fdaf106e8b82fea8ea802b027fb3937c06420d869f8267 |
| MD5 | a2a389f6640501205605c8ddfd909b5c |
| BLAKE2b-256 | 3f06b749fece6e43d479d51b5cb4e83bae5f1f6bd4db57d818acc6b5e0cec75d |

