Adaptive Class Weight Adjustment (ACWA) for Imbalanced Deep Learning
Adaptive Class Weight Adjustment (ACWA) - Automated Class Balancing for Deep Learning
📖 Table of Contents
- Overview
- Key Features
- Algorithm Design
- When to Use ACWA
- Implementation Guide
- Usage for Different Datasets
- FAQ
- Best Practices
- Benchmark Results
- Contributing
🌟 Overview
ACWA is a training technique that automatically adjusts class weights during neural network training. It is aimed at imbalanced datasets. Unlike static weighting schemes, ACWA updates the weights during training based on per-class performance metrics such as the F1-score.
Limitations of Traditional Methods:
- Static class weighting based on frequency
- Manual oversampling/undersampling
- Fixed cost-sensitive learning
ACWA Advantages:
- 🚀 Real-time performance monitoring
- ⚖️ Dynamic weight adjustment
- 🎯 Focus on underperforming classes
- 🤖 No manual re-weighting needed
✨ Key Features
- Adaptive Learning: Adjusts weights based on validation performance
- Smoothing Mechanism: Prevents drastic weight fluctuations
- Multi-class Support: Works with any number of classes
- Framework Agnostic: The algorithm itself is framework-independent; the current package targets PyTorch, with TensorFlow/Keras support planned
- Plug-and-Play: Easy integration into existing pipelines
- TorchMetrics Integration: Efficient F1-score calculation
- Dynamic Weight Initialization: Supports inverse class frequency
- Early Stopping: Helps prevent overfitting by monitoring validation performance
- Numerical Stability: Epsilon added to class frequency for robust weight initialization
🧠 Algorithm Design
Core Concept
ACWA operates through a feedback loop:
- Monitor class-wise performance
- Calculate performance gaps
- Adjust weights dynamically
Mathematical Formulation
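Performance Error:
$$\text{error}_c = \text{target} - \text{metric}_c$$
Weight Update:
$$w_c^{(t+1)} = \operatorname{clip}\left(\beta\, w_c^{(t)} + (1-\beta)\left(w_c^{(t)} + \alpha\, \text{error}_c\right),\ 0.5,\ 2.0\right)$$
Loss Modification:
$$\mathcal{L} = \sum_{c=1}^{C} w_c\, \mathcal{L}_c$$
where $w_c$ is the weight of class $c$, $\text{metric}_c$ is its current metric (e.g., F1-score), $\mathcal{L}_c$ is the loss over samples of class $c$, and $C$ is the number of classes.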
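The three formulas can be sketched in plain Python as follows. This is a minimal illustration, assuming per-class metrics are already computed and that higher is better (as with F1). The names `acwa_step` and `weighted_loss` are hypothetical and not part of the acwa package; a real training loop would apply the weights to a framework loss tensor rather than to a list of floats.

```python
from typing import List, Sequence


def acwa_step(weights: Sequence[float], metrics: Sequence[float], target: float,
              alpha: float = 0.02, beta: float = 0.9,
              w_min: float = 0.5, w_max: float = 2.0) -> List[float]:
    """Apply one ACWA update to every class weight (hypothetical helper).

    error_c = target - metric_c, then the raw step w_c + alpha * error_c is
    blended with the previous weight using beta and clipped to [w_min, w_max].
    """
    updated = []
    for w, m in zip(weights, metrics):
        error = target - m                     # positive when class c underperforms
        smoothed = beta * w + (1.0 - beta) * (w + alpha * error)
        updated.append(min(max(smoothed, w_min), w_max))
    return updated


def weighted_loss(per_class_losses: Sequence[float], weights: Sequence[float]) -> float:
    """L = sum over classes c of w_c * L_c."""
    return sum(w * loss for w, loss in zip(weights, per_class_losses))


weights = [1.0, 1.0, 1.0]
f1_scores = [0.95, 0.80, 0.40]   # illustrative per-class F1 values
weights = acwa_step(weights, f1_scores, target=0.85)
print(weights)                   # the class with the lowest F1 gets the largest weight
print(weighted_loss([0.2, 0.5, 1.1], weights))
```

Clipping happens after smoothing, so a single update can never push a weight outside [0.5, 2.0], no matter how large the error is.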
Hyperparameters
| Parameter | Description | Recommended Value |
|---|---|---|
| α | Weight update rate (step size) | 0.01-0.05 |
| β | Smoothing factor (weight on the previous value) | 0.8-0.95 |
| K | Update frequency | 50-200 batches |
| Target | Performance goal | Class-specific |
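Expanding the smoothing term shows how α and β interact: β·w + (1 - β)(w + α·error) = w + (1 - β)·α·error. Before clipping, each update therefore moves a weight by (1 - β)·α per unit of error. The sketch below uses this identity to estimate how quickly weights can adapt. The helper names and the constant-error assumption are illustrative only.

```python
import math


def effective_step(alpha: float, beta: float) -> float:
    """Weight change per update per unit of error, before clipping.

    beta * w + (1 - beta) * (w + alpha * e) simplifies to w + (1 - beta) * alpha * e.
    """
    return (1.0 - beta) * alpha


def updates_to_reach(w0: float, w_goal: float, error: float,
                     alpha: float, beta: float) -> int:
    """Updates needed to move a weight from w0 up to w_goal, assuming
    (for illustration) the error stays constant and positive."""
    step = effective_step(alpha, beta) * error
    if step <= 0 or w_goal <= w0:
        raise ValueError("expects a positive error and w_goal > w0")
    return math.ceil((w_goal - w0) / step)


alpha, beta, K = 0.02, 0.9, 100          # values inside the recommended ranges
n = updates_to_reach(1.0, 2.0, error=0.3, alpha=alpha, beta=beta)
print(f"effective step: {effective_step(alpha, beta):.4f} per unit of error")
print(f"updates to reach 2.0: {n} (about {n * K} batches with K={K})")
```

With α = 0.02 and β = 0.9, the effective step is 0.002 per unit of error. A weight under a constant error of 0.3 would need 1,667 updates to climb from 1.0 to the 2.0 clip bound, which is about 166,700 batches at K = 100. If weights barely move in practice, you can enlarge the step by lowering β or raising α.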
🏆 When to Use ACWA
Ideal Scenarios
- 🏥 Medical diagnosis (rare disease detection)
- 💳 Fraud detection
- ⚠️ Rare event prediction
- 🛡️ Anomaly detection
- 📊 Highly imbalanced datasets
Comparison with Alternatives
| Method | Pros | Cons |
|---|---|---|
| ACWA | Adaptive, automatic | Slightly more compute |
| Class Weighting | Simple | Static, manual tuning |
| Resampling | Balances data | May lose information |
| Focal Loss | Handles hard samples | Fixed strategy |
💻 Implementation Guide
Installation
pip install acwa-trainer
Basic Usage
from acwa import ACWATrainer

# Initialize
trainer = ACWATrainer(
    num_classes=10,
    target_metric=0.85,  # Target F1-score
    alpha=0.02,
    beta=0.9,
    update_freq=100
)

# Training loop (assumes model, optimizer and dataloader are already defined)
for inputs, labels in dataloader:
    optimizer.zero_grad()
    # Forward pass
    outputs = model(inputs)
    # ACWA-weighted loss
    loss = trainer.get_weighted_loss(outputs, labels)
    # Backward pass
    loss.backward()
    optimizer.step()
    # Update per-class metrics
    trainer.update_metrics(outputs, labels)
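To make the role of update_freq concrete, here is a hypothetical stand-in, `MiniACWA`, for the kind of bookkeeping a trainer like the one above might do. It is not the acwa package's implementation. It takes predicted class indices rather than logits, accumulates per-class true positives, false positives and false negatives, and applies the ACWA update once every update_freq batches.

```python
from typing import Sequence


class MiniACWA:
    """Hypothetical stand-in, not the acwa package's ACWATrainer.

    Accumulates per-class TP/FP/FN from predicted class indices and applies
    the ACWA update rule once every update_freq batches.
    """

    def __init__(self, num_classes: int, target_metric: float = 0.85,
                 alpha: float = 0.02, beta: float = 0.9, update_freq: int = 100):
        self.num_classes = num_classes
        self.target = target_metric
        self.alpha = alpha
        self.beta = beta
        self.update_freq = update_freq
        self.weights = [1.0] * num_classes   # uniform initial weights
        self.batches_seen = 0
        self._reset_counts()

    def _reset_counts(self) -> None:
        self.tp = [0] * self.num_classes
        self.fp = [0] * self.num_classes
        self.fn = [0] * self.num_classes

    def update_metrics(self, preds: Sequence[int], labels: Sequence[int]) -> None:
        """Record one batch; every update_freq batches, refresh the weights."""
        for p, y in zip(preds, labels):
            if p == y:
                self.tp[y] += 1
            else:
                self.fp[p] += 1   # class p was predicted wrongly
                self.fn[y] += 1   # true class y was missed
        self.batches_seen += 1
        if self.batches_seen % self.update_freq == 0:
            self._update_weights()

    def _f1(self, c: int) -> float:
        # F1 = 2TP / (2TP + FP + FN); a class absent from the window scores 0.0
        denom = 2 * self.tp[c] + self.fp[c] + self.fn[c]
        return 2 * self.tp[c] / denom if denom else 0.0

    def _update_weights(self) -> None:
        for c in range(self.num_classes):
            w = self.weights[c]
            error = self.target - self._f1(c)
            w = self.beta * w + (1 - self.beta) * (w + self.alpha * error)
            self.weights[c] = min(max(w, 0.5), 2.0)
        self._reset_counts()   # the next window starts fresh


acwa = MiniACWA(num_classes=2, update_freq=2)
acwa.update_metrics([0, 0, 1], [0, 1, 1])
acwa.update_metrics([0, 0], [0, 1])      # second batch triggers a weight update
print(acwa.weights)
```

In this sketch, metrics are reset after each update, so every update reflects only the most recent update_freq batches. That is one design choice among several; a running average over all batches would react more slowly.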
Advanced Features
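# Custom metrics (custom_f1_function is your own metric function)
trainer = ACWATrainer(
    metric_fn=custom_f1_function,
    metric_mode='max'  # or 'min'
)
# Combined with Focal Loss (FocalLoss is not part of torch.nn; define or import it)
trainer = ACWATrainer(
    loss_fn=FocalLoss(gamma=2.0),
    # ... other arguments
)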
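One way to combine the two ideas is to scale the standard focal loss -(1 - p_t)^γ·log(p_t) by the ACWA weight of each sample's true class. The plain-Python sketch below does this. The function names and the batch-mean reduction are assumptions for illustration, not the package's loss_fn interface.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)                          # shift by the max for numerical stability
    log_total = math.log(sum(math.exp(z - m) for z in logits))
    return [z - m - log_total for z in logits]


def weighted_focal_loss(logits: Sequence[float], label: int,
                        weights: Sequence[float], gamma: float = 2.0) -> float:
    """w_label * -(1 - p_t)^gamma * log(p_t) for a single sample."""
    log_p = log_softmax(logits)[label]
    p_t = math.exp(log_p)
    return -weights[label] * (1.0 - p_t) ** gamma * log_p


def weighted_focal_batch(batch_logits: Sequence[Sequence[float]], labels: Sequence[int],
                         weights: Sequence[float], gamma: float = 2.0) -> float:
    """Mean of the per-sample weighted focal losses (mean reduction is a choice)."""
    losses = [weighted_focal_loss(z, y, weights, gamma)
              for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


acwa_weights = [0.8, 1.6]                  # e.g. current ACWA weights for two classes
print(weighted_focal_batch([[2.0, -1.0], [0.5, 0.3]], [0, 1], acwa_weights))
```

With γ = 0 this reduces to class-weighted cross-entropy. Larger γ down-weights samples the model already classifies confidently, while the ACWA weights act per class.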
import torch

# Initialize weights using inverse class frequency
# minlength keeps a zero count for classes with no samples
class_counts = torch.bincount(torch.tensor(trainset.targets), minlength=10)
class_frequencies = class_counts.float() / (class_counts.sum() + 1e-6)
trainer = ACWATrainer(
    model=model,
    num_classes=10,
    class_frequencies=class_frequencies
)
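When frequencies are turned into weights, the epsilon does most of its work in the denominator of 1/(frequency + ε). There it prevents a class with no samples from producing an infinite weight. The helper below is hypothetical and assumes labels are integer class indices. Rescaling to a mean of 1.0 and clipping to [0.5, 2.0] mirror the update rule's bounds; both are choices made for this sketch.

```python
from typing import List, Sequence


def inverse_frequency_weights(labels: Sequence[int], num_classes: int,
                              eps: float = 1e-6,
                              w_min: float = 0.5, w_max: float = 2.0) -> List[float]:
    """Hypothetical helper: initial weights proportional to 1 / (frequency + eps).

    Weights are rescaled to average 1.0, then clipped to [w_min, w_max].
    Expects a non-empty label list.
    """
    counts = [0] * num_classes
    for y in labels:
        counts[y] += 1
    total = sum(counts)
    freqs = [c / total for c in counts]
    raw = [1.0 / (f + eps) for f in freqs]   # eps keeps empty classes finite
    mean_raw = sum(raw) / num_classes
    return [min(max(r / mean_raw, w_min), w_max) for r in raw]


init_weights = inverse_frequency_weights([0] * 90 + [1] * 10, num_classes=2)
print(init_weights)   # the majority class hits the 0.5 floor
```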
# Early stopping example
# (assumes num_epochs is set and val_f1 is the validation macro F1 computed each epoch)
best_f1 = 0
early_stop_counter = 0
patience = 5
for epoch in range(num_epochs):
    # ...training and validation logic that computes val_f1...
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), 'best_model.pth')
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("Early stopping triggered.")
            break
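The same patience logic can be packaged as a small reusable helper. This `EarlyStopping` class is an illustrative sketch, not part of the acwa package. It also accepts an optional `min_delta`, the minimum improvement that resets the counter.

```python
class EarlyStopping:
    """Illustrative helper: stop when a score has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.counter = 0
        self.improved = False   # True right after a new best score (time to checkpoint)

    def step(self, score: float) -> bool:
        """Record one epoch's score (higher is better); return True to stop."""
        self.improved = score > self.best + self.min_delta
        if self.improved:
            self.best = score
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopping(patience=2)
stop_epoch = None
for epoch, val_f1 in enumerate([0.50, 0.60, 0.59, 0.58, 0.70]):
    if stopper.step(val_f1):
        stop_epoch = epoch
        break
print(stop_epoch, stopper.best)
```

In a real loop, saving the model whenever `stopper.improved` is true reproduces the checkpointing in the example above.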
📚 Usage for Different Datasets
Vision Task (CIFAR-100)
- Prepare the dataset:
from demo_cifar100 import main
main()
NLP Task (IMDB with Hugging Face)
- Use the Hugging Face datasets library to load the IMDB dataset:
from demo_huggingface import main
main()
Time-Series Task (UCR)
- Load a UCR dataset using tslearn or a similar library.
❓ FAQ
- How do I integrate Hugging Face models?
  - Use the transformers library and wrap the model with ACWATrainer.
- What if the F1-score does not improve?
  - Check update_freq; a smaller value (e.g., 20-50 batches) updates the weights more often.
  - Check that the dataset is properly preprocessed.
📝 Best Practices
- Validation Set: Ensure it has a representative class distribution
- Initial Weights: Start with uniform weights (1.0)
- Hyperparameter Tuning:
  - Start with α=0.01, β=0.9
  - Adjust based on convergence
- Monitoring: Track weight evolution during training
- Combination Strategies:
  - Works well with data augmentation
  - Can be combined with focal loss
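import matplotlib.pyplot as plt

# Example weight evolution plot
# weight_history: one entry per update step, each a list of per-class weights
plt.plot(weight_history)  # one line per class
plt.title('ACWA Weight Adjustment')
plt.xlabel('Update Steps')
plt.ylabel('Class Weight')
plt.show()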
🏅 Benchmark Results
CIFAR-10 (Imbalanced)
| Method | Accuracy | Macro F1 | Training Time |
|---|---|---|---|
| ACWA (Version 3) | 86.3% | 0.781 | 0.7h |
| ACWA (Final) | 87.5% | 0.799 | 0.65h |
🤝 Contributing
We welcome contributions! Please see our:
Future Improvements
- Unit Testing:
  - Add test cases for edge scenarios (e.g., empty classes, small batch sizes).
  - Ensure compatibility with various datasets and imbalance ratios.
- Distributed Training:
  - Implement support for multi-GPU setups using torch.nn.parallel.DistributedDataParallel.
  - Synchronize metrics across GPUs for consistent weight updates.
- Additional Frameworks:
  - Extend support to TensorFlow/Keras for broader adoption.
📜 License
MIT License - Free for academic and commercial use
Download files
File details
Details for the file acwa_trainer-1.0.0.tar.gz.
File metadata
- Download URL: acwa_trainer-1.0.0.tar.gz
- Upload date:
- Size: 10.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 52e5aa06b5271ea2dd007739c0d0fd93fc5cce01de0924c95000cac6a017712e |
| MD5 | e17c024e0f38282c22aff054baa79de7 |
| BLAKE2b-256 | b25492257c5bdc99a2cc01dcba4361190a512619215ba3740f6f433b20207e6e |
File details
Details for the file acwa_trainer-1.0.0-py3-none-any.whl.
File metadata
- Download URL: acwa_trainer-1.0.0-py3-none-any.whl
- Upload date:
- Size: 11.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 126b566ae51883fe37fdaf106e8b82fea8ea802b027fb3937c06420d869f8267 |
| MD5 | a2a389f6640501205605c8ddfd909b5c |
| BLAKE2b-256 | 3f06b749fece6e43d479d51b5cb4e83bae5f1f6bd4db57d818acc6b5e0cec75d |