🎯 DistAwareAug
Distribution-Aware Data Augmentation for Imbalanced Learning
Intelligent oversampling that preserves statistical distributions and ensures sample diversity
Features • Installation • Quick Start • Documentation • Contributing • Citation
📋 Overview
DistAwareAug introduces a new paradigm for handling imbalanced datasets: statistically-governed augmentation. Unlike traditional methods like SMOTE that interpolate between samples, DistAwareAug:
- 📊 Learns the statistical distribution of minority class features (mean, variance, covariance)
- 🎲 Generates synthetic samples from fitted distributions (KDE or Gaussian)
- 🎯 Ensures diversity through distance-based filtering
- ✅ Preserves feature bounds and avoids unrealistic outliers
Why DistAwareAug?
| Feature | SMOTE | BorderlineSMOTE | ADASYN | DistAwareAug |
|---|---|---|---|---|
| Preserves distributions | ❌ | ❌ | ❌ | ✅ |
| Diversity control | ⚠️ | ⚠️ | ⚠️ | ✅ |
| Flexible sampling (add/target) | ❌ | ❌ | ❌ | ✅ |
| Avoids interpolation artifacts | ❌ | ❌ | ❌ | ✅ |
| Statistical governance | ❌ | ❌ | ❌ | ✅ |
| Supports downsampling | ❌ | ❌ | ❌ | ✅ |
| Distribution methods (KDE/Gaussian) | ❌ | ❌ | ❌ | ✅ |
✨ Features
- 🔬 Distribution Fitting: KDE (Kernel Density Estimation) or Gaussian distributions
- 📏 Distance Metrics: Euclidean, Manhattan, Cosine, Minkowski, and more
- 🎲 Diversity Control: Configurable threshold for sample uniqueness
- 🎯 Flexible Sampling Modes:
  - `'add'` mode: Add N samples to the existing class count
  - `'target'` mode: Target absolute sample counts (supports both upsampling and downsampling)
- 🔌 scikit-learn Compatible: Standard `fit_resample()` API
- ⚡ Performance Optimized: KD-Tree-based diversity checking (10-13x faster than v0.1.0)
- 📊 Quality Metrics: Built-in diversity scoring and validation
- 🛡️ Robust: Handles edge cases, singular matrices, and various data types
🚀 What's New in v0.2.0
Major Performance Improvements
- 10-13x Speedup: Replaced O(n²) diversity checking with a KD-Tree (O(log n) per query)
- Batch Generation: Increased batch size for reduced Python overhead
- Vectorized Operations: Optimized clipping and distance calculations
- Parallel Processing: Multi-core neighbor queries with `n_jobs=-1`
Key Changes
- ✅ KD-Tree Diversity Checking: Checks all synthetic samples efficiently
- ✅ Better Documentation: Clear parameter explanations and examples
- ✅ All Tests Passing: Comprehensive test coverage
Performance Comparison
Benchmark (1000 samples, 20 features, 9:1 imbalance):
SMOTE: 0.007s
DistAwareAug: 0.05-0.08s (7-12x slower than SMOTE)
v0.1.0 was: 0.6-0.7s (91x slower than SMOTE)
Result: 10-13x speedup while maintaining quality! 🎉
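The KD-Tree speedup can be sketched with SciPy (an illustrative sketch, not the library's internal code; `cKDTree`, the array sizes, and the 0.1 threshold are assumptions for the example):

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(42)
existing = rng.normal(size=(500, 20))    # already-accepted samples
candidates = rng.normal(size=(100, 20))  # newly generated candidates
threshold = 0.1                          # minimum allowed nearest-neighbor distance

# One KD-Tree query answers "how far is each candidate from its nearest
# existing sample?" in roughly O(log n) per candidate, instead of O(n)
# brute-force distance comparisons per candidate.
tree = cKDTree(existing)
dists, _ = tree.query(candidates, k=1, workers=-1)  # workers=-1: use all cores

diverse = candidates[dists >= threshold]  # keep only sufficiently novel samples
print(diverse.shape)
```

The same tree is reused for every candidate batch, which is where the 10-13x improvement over pairwise checking comes from.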
💡 Best Practices & Recommendations
✅ When to Use DistAwareAug
Ideal Use Cases:
- Moderate imbalance (2:1 to 50:1 ratio)
- Multi-modal distributions (data with multiple clusters)
- High-dimensional data where SMOTE may create unrealistic samples
- When sample quality matters more than generation speed
- Research applications requiring statistical rigor
Example Scenarios:
- Medical diagnosis with rare diseases (class imbalance 5:1 to 30:1)
- Fraud detection with multiple fraud patterns
- Customer churn prediction with diverse customer segments
⚠️ When NOT to Use DistAwareAug
Not Recommended For:
- Extreme imbalance (>100:1) - Use SMOTE/ADASYN instead
- Very few minority samples (<50 samples in high dimensions)
- Speed-critical applications where 7-12x slower than SMOTE is unacceptable
- Simple linear separability where SMOTE works fine
🎯 Parameter Tuning Guide

| Parameter | Low Imbalance (2:1 to 10:1) | Moderate (10:1 to 50:1) | Notes |
|---|---|---|---|
| `diversity_threshold` | 0.05 - 0.1 | 0.1 - 0.2 | Higher = more diverse samples |
| `distribution_method` | `'kde'` | `'kde'` or `'gaussian'` | KDE for multi-modal, Gaussian for speed |
| `distance_metric` | `'euclidean'` | `'euclidean'` or `'manhattan'` | Manhattan is robust to outliers |
Pro Tips:
- Scale your features before using DistAwareAug (use `StandardScaler`)
- Start with `'gaussian'` for speed, switch to `'kde'` if needed
- Tune `diversity_threshold` based on your feature scale
- Use `sampling_mode='target'` for precise control over the class distribution
📦 Installation
From Source (Recommended for Development)
# Clone the repository
git clone https://github.com/Ayo-Cyber/DistAwareAug.git
cd DistAwareAug
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install in editable mode with all dependencies
pip install -e ".[all]"
From PyPI
pip install distawareaug
With Optional Dependencies
# For development (testing, linting, formatting)
pip install distawareaug[dev]
# For running examples
pip install distawareaug[examples]
# For building documentation
pip install distawareaug[docs]
# Everything
pip install distawareaug[all]
🎯 Quick Start
Basic Usage
from distawareaug import DistAwareAugmentor
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
# Create imbalanced dataset
X, y = make_classification(
    n_samples=1000,
    n_features=20,
    weights=[0.9, 0.1],  # 90% majority, 10% minority
    random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
# Initialize DistAwareAugmentor
augmentor = DistAwareAugmentor(
    sampling_strategy='auto',   # Balance all classes
    diversity_threshold=0.1,    # Minimum distance between samples
    distribution_method='kde',  # or 'gaussian' for speed
    random_state=42
)
# Oversample the training data
X_resampled, y_resampled = augmentor.fit_resample(X_train, y_train)
# Train classifier on balanced data
clf = RandomForestClassifier(random_state=42)
clf.fit(X_resampled, y_resampled)
# Evaluate
score = clf.score(X_test, y_test)
print(f"Test Accuracy: {score:.4f}")
Advanced Usage
from distawareaug import DistAwareAugmentor, DistanceMetrics
from sklearn.preprocessing import StandardScaler
# Scale features for better diversity control
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
# Example 1: Add mode (adds N samples to existing count)
augmentor_add = DistAwareAugmentor(
    sampling_strategy={0: 0, 1: 500},  # Add 500 samples to class 1, no change to class 0
    sampling_mode='add',               # Default: adds to existing count
    diversity_threshold=0.15,
    distribution_method='gaussian',
    distance_metric='manhattan',
    random_state=42
)
X_resampled, y_resampled = augmentor_add.fit_resample(X_train_scaled, y_train)
# Example 2: Target mode (targets absolute count, can upsample or downsample)
augmentor_target = DistAwareAugmentor(
    sampling_strategy={0: 1000, 1: 1000},  # Target exactly 1000 samples for each class
    sampling_mode='target',                # Will upsample minority, downsample majority
    diversity_threshold=0.15,
    distribution_method='gaussian',
    random_state=42
)
X_balanced, y_balanced = augmentor_target.fit_resample(X_train_scaled, y_train)
# Analyze diversity
dm = DistanceMetrics(metric='euclidean')
diversity = dm.diversity_score(X_resampled[len(X_train):])
print(f"Synthetic Sample Diversity: {diversity:.4f}")
Understanding Sampling Modes
DistAwareAug offers two sampling modes that control how sampling_strategy values are interpreted:
'add' Mode (Default)
Adds N samples to the existing class count:
# If class 1 has 100 samples originally
augmentor = DistAwareAugmentor(
    sampling_strategy={1: 500},
    sampling_mode='add'  # Default
)
# Result: class 1 will have 100 + 500 = 600 samples
'target' Mode
Targets an absolute number of samples (can upsample or downsample):
# If class 0 has 5000 samples and class 1 has 500 samples
augmentor = DistAwareAugmentor(
    sampling_strategy={0: 1000, 1: 1000},
    sampling_mode='target'
)
# Result: class 0 downsampled to 1000, class 1 upsampled to 1000
Use cases:
- 'add' mode: When you want to generate a specific number of additional samples
- 'target' mode: When you want to balance classes to exact counts
📖 Documentation
Core Components
DistAwareAugmentor
Main class for oversampling imbalanced datasets.
Parameters:
- `sampling_strategy` (str or dict, default='auto'): How to balance classes
  - `'auto'`: Balance all classes to the majority class size
  - dict: Specify number of samples per class, e.g., `{0: 100, 1: 200}`
- `sampling_mode` (str, default='add'): How to interpret `sampling_strategy` dict values
  - `'add'`: Add N samples to the existing class count (e.g., `{1: 500}` adds 500 to class 1)
  - `'target'`: Target N total samples for the class (e.g., `{1: 500}` results in exactly 500 samples)
  - Note: `'target'` mode supports both upsampling and downsampling
- `diversity_threshold` (float, default=0.1): Minimum distance for sample acceptance
  - Important: Scale your features for consistent behavior!
- `distribution_method` (str, default='kde'): Distribution fitting method
  - `'kde'`: Kernel Density Estimation (more accurate, slower)
  - `'gaussian'`: Multivariate Gaussian (faster, assumes normality)
- `distance_metric` (str, default='euclidean'): Distance metric for diversity
  - Options: `'euclidean'`, `'manhattan'`, `'cosine'`, `'minkowski'`, etc.
- `random_state` (int, default=None): Random seed for reproducibility

Methods:
- `fit(X, y)`: Fit the augmentor to training data
- `resample(X, y)`: Generate synthetic samples
- `fit_resample(X, y)`: Fit and resample in one step
DistanceMetrics
Utilities for computing distances and diversity scores.
Key Methods:
- `compute_distances(X, Y)`: Pairwise distances between samples
- `nearest_neighbor_distances(X, Y)`: Distance to the nearest neighbor
- `diversity_score(samples, reference)`: Measure sample diversity
- `filter_diverse_samples(samples, threshold)`: Keep only diverse samples
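To illustrate what a diversity score captures, here is a sketch using SciPy directly (this is not the DistAwareAug implementation; mean pairwise distance is one common choice of diversity measure):

```python
import numpy as np
from scipy.spatial.distance import pdist

rng = np.random.default_rng(0)
samples = rng.normal(size=(50, 5))  # e.g., a batch of synthetic samples

# Mean pairwise Euclidean distance: higher values mean the samples
# are more spread out (more diverse), lower values mean clumping.
diversity = pdist(samples, metric='euclidean').mean()
print(f"Diversity: {diversity:.4f}")
```

A threshold-based filter like `filter_diverse_samples` would instead drop any sample closer than the threshold to another sample.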
DistributionFitter
Fits statistical distributions to feature data.
Supported Distributions:
- `'kde'`: Kernel Density Estimation
- `'gaussian'`: Multivariate Gaussian
- `'uniform'`: Uniform distribution (for testing)
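Both fitting strategies can be sketched with NumPy/SciPy (an illustration of the idea, not the `DistributionFitter` internals; the toy data and sample counts are assumptions):

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
minority = rng.normal(loc=2.0, scale=0.5, size=(200, 3))  # toy minority class

# KDE route: fit a kernel density estimate and draw new samples from it.
kde = gaussian_kde(minority.T)        # gaussian_kde expects shape (n_features, n_samples)
kde_samples = kde.resample(100).T     # back to (n_samples, n_features)

# Gaussian route: fit mean and covariance, then sample a multivariate normal.
mean = minority.mean(axis=0)
cov = np.cov(minority, rowvar=False)
gauss_samples = rng.multivariate_normal(mean, cov, size=100)

print(kde_samples.shape, gauss_samples.shape)  # (100, 3) (100, 3)
```

KDE adapts to multi-modal shapes at the cost of speed; the Gaussian fit is fast but assumes a single mode.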
🧪 Testing
Run All Tests
# Run tests with coverage
pytest
# Run with verbose output
pytest -v
# Run specific test file
pytest tests/test_augmentor.py
# Run tests matching pattern
pytest -k "test_diversity"
Run Tests with Coverage Report
# Generate HTML coverage report
pytest --cov=distawareaug --cov-report=html
# Open coverage report
open htmlcov/index.html # macOS
xdg-open htmlcov/index.html # Linux
start htmlcov/index.html # Windows
Test Structure
tests/
├── test_augmentor.py     # Tests for DistAwareAugmentor
├── test_distance.py      # Tests for distance metrics
├── test_distribution.py  # Tests for distribution fitting
└── __pycache__/
Writing New Tests
import pytest
import numpy as np
from distawareaug import DistAwareAugmentor
def test_my_feature():
    """Test description."""
    X = np.random.randn(100, 5)
    y = np.random.randint(0, 2, 100)
    augmentor = DistAwareAugmentor(random_state=42)
    X_resampled, y_resampled = augmentor.fit_resample(X, y)
    assert len(X_resampled) >= len(X)
    assert len(np.unique(y_resampled)) == len(np.unique(y))
Run CI Tests Locally
Before pushing, run the same checks that GitHub Actions will run:
# Run all CI checks locally (formatting, linting, tests)
sh run_ci_tests.sh
This script will:
- ✅ Check code formatting with Black
- ✅ Check import sorting with isort
- ✅ Run linting with flake8
- ✅ Run all tests with pytest
- ✅ Generate a coverage report
Auto-Fix Linting Issues
If you have linting errors (unused imports, variables, etc.):
# Automatically remove unused imports and variables
python fix_linting.py
This will:
- Remove unused imports
- Remove unused variables
- Format code with Black
- Sort imports with isort
Then run sh run_ci_tests.sh again to verify!
🤝 Contributing
We welcome contributions! Here's how to get started:
Setting Up Development Environment
# Fork and clone the repository
git clone https://github.com/YOUR_USERNAME/DistAwareAug.git
cd DistAwareAug
# Create virtual environment
python -m venv venv
source venv/bin/activate # Windows: venv\Scripts\activate
# Install in development mode with all dependencies
pip install -e ".[dev]"
# Install pre-commit hooks (optional but recommended)
pre-commit install
Development Workflow
1. Create a new branch for your feature/fix:

   git checkout -b feature/your-feature-name

2. Make your changes and ensure code quality:

   # Auto-fix common issues
   python fix_linting.py

   # Or manually format
   black distawareaug tests
   isort distawareaug tests
   flake8 distawareaug tests

3. Run CI checks locally:

   # Run all checks (formatting, linting, tests)
   sh run_ci_tests.sh

4. Write/update tests:

   # Run tests
   pytest -v

   # Check coverage
   pytest --cov=distawareaug --cov-report=term-missing

5. Update documentation if needed:
   - Update README.md for user-facing changes
   - Update docstrings for library reference changes
   - Add examples for new features

6. Commit your changes:

   git add .
   git commit -m "feat: add your feature description"

   Follow Conventional Commits:
   - `feat:` New feature
   - `fix:` Bug fix
   - `docs:` Documentation changes
   - `test:` Test changes
   - `refactor:` Code refactoring
   - `perf:` Performance improvements
   - `chore:` Maintenance tasks

7. Push and create a Pull Request:

   git push origin feature/your-feature-name

   Then open a PR on GitHub with a clear description.
Pro tip: Use make commands for common tasks:
make format # Format code with black and isort
make lint # Check linting
make test # Run tests
make test-cov # Run tests with coverage
make check # Run all checks
make clean # Clean build artifacts
See CONTRIBUTING.md for detailed guidelines.
Code Style Guidelines
- Line length: 100 characters (enforced by Black)
- Docstrings: Google-style or NumPy-style
- Type hints: Encouraged but not required
- Naming:
  - `snake_case` for functions/variables
  - `PascalCase` for classes
  - `UPPER_CASE` for constants
Example Contribution Areas
- 🐛 Bug Fixes: Found an issue? Submit a PR!
- ✨ New Features:
  - Additional distribution methods
  - New distance metrics
  - Performance optimizations
- 📚 Documentation: Improve examples, tutorials, docstrings
- 🧪 Tests: Increase test coverage
- 📓 Examples: Add Jupyter notebooks demonstrating use cases
Questions?
- Open an Issue for bugs or feature requests
- Start a Discussion for questions
📚 Examples
Explore comprehensive examples in the examples/ directory:
Available Notebooks
- demo_synthetic.ipynb: Introduction to DistAwareAug with synthetic data
- compare_smote.ipynb: Benchmark comparison with SMOTE, ADASYN, etc.
- comprehensive_test.ipynb: Full performance analysis with threshold optimization
Running Examples
# Install example dependencies
pip install ".[examples]"
# Start Jupyter
jupyter notebook examples/
🔧 Troubleshooting
Common Issues
1. Low Diversity / Few Samples Generated
Problem: diversity_threshold is too high for your feature scales.
Solution: Scale your features first!
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
augmentor = DistAwareAugmentor(diversity_threshold=0.1)
X_resampled, y_resampled = augmentor.fit_resample(X_train_scaled, y_train)
2. Singular Matrix Errors
Problem: Too few samples or perfectly correlated features.
Solution: DistAwareAug handles this automatically, but you can:
- Use `distribution_method='gaussian'`, which adds regularization
- Ensure sufficient samples (>10 per class recommended)
- Remove perfectly correlated features
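The regularization idea can be sketched as follows (a hypothetical illustration of the standard diagonal-ridge trick; the epsilon value and toy data are assumptions, not the library's actual constants):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 4))
X[:, 3] = X[:, 0]  # perfectly correlated feature -> singular covariance matrix

cov = np.cov(X, rowvar=False)
eps = 1e-6
# Adding a tiny ridge to the diagonal makes the covariance positive
# definite again without meaningfully changing the fitted distribution.
cov_reg = cov + eps * np.eye(cov.shape[0])

# The regularized covariance can now be used for sampling:
samples = rng.multivariate_normal(X.mean(axis=0), cov_reg, size=10)
print(samples.shape)  # (10, 4)
```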
3. Slow Performance
Problem: KDE is slow for large datasets or diversity checking takes too long.
Solutions:
# Solution 1: Use Gaussian method (3-5x faster than KDE)
augmentor = DistAwareAugmentor(distribution_method='gaussian')
# Solution 2: Lower diversity threshold (faster, more samples accepted)
augmentor = DistAwareAugmentor(diversity_threshold=0.05)
# Solution 3: Combine both for maximum speed
augmentor = DistAwareAugmentor(
    distribution_method='gaussian',
    diversity_threshold=0.05
)
Note: DistAwareAug is typically 5-15x slower than SMOTE due to distribution fitting and diversity enforcement. This is expected and provides higher quality synthetic samples.
4. Import Errors
Problem: Missing dependencies.
Solution:
pip install -e ".[all]"
📈 Performance Tips
Speed Considerations
DistAwareAug is typically 5-15x slower than SMOTE due to:
- Distribution fitting (KDE or Gaussian)
- Diversity enforcement (checks samples against random subsample for efficiency)
For reference on a 5,000 sample dataset (9:1 imbalance):
- SMOTE: ~0.007s
- ADASYN: ~0.035s (5x slower than SMOTE)
- DistAwareAug: ~0.05-0.08s (7-12x slower than SMOTE)
This trade-off provides better quality synthetic data that preserves statistical distributions.
Optimization Tips
- Scale your features with `StandardScaler` for consistent `diversity_threshold` behavior
- Use `distribution_method='gaussian'` for large datasets (3-5x faster than KDE):
  augmentor = DistAwareAugmentor(distribution_method='gaussian')  # Faster
- Adjust `diversity_threshold` based on your needs:
  - Higher (0.2-0.5): More diverse samples, fewer total samples, slower
  - Lower (0.05-0.1): More samples, less diversity, faster
- Set `random_state` for reproducible results
- Start with small datasets to tune parameters before scaling up
How Diversity Checking Works
DistAwareAug ensures synthetic samples are diverse by checking they are sufficiently far from existing samples. For performance, diversity is checked against a random subsample of up to 200 existing synthetic samples rather than all of them.
Why this works:
- Checking all pairwise distances would be O(n²), which is extremely slow for thousands of samples
- Random sampling provides ~95% of the quality with 10x+ better performance
- Similar to statistical polling: you don't need to survey everyone to get accurate results
This approximation is statistically sound and provides excellent quality/speed balance.
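A minimal sketch of this subsampled acceptance check (illustrative only; the cap of 200 comes from the text above, while the function name, signature, and toy data are assumptions):

```python
import numpy as np

def is_diverse(candidate, accepted, threshold, max_ref=200, rng=None):
    """Accept a candidate only if it is at least `threshold` away from a
    random subsample of previously accepted samples."""
    rng = rng or np.random.default_rng()
    if len(accepted) == 0:
        return True
    # Compare against at most max_ref random references instead of all of them.
    idx = rng.choice(len(accepted), size=min(max_ref, len(accepted)), replace=False)
    dists = np.linalg.norm(accepted[idx] - candidate, axis=1)
    return bool(dists.min() >= threshold)

rng = np.random.default_rng(1)
accepted = rng.normal(size=(1000, 5))
# A candidate far from every accepted sample passes the check:
print(is_diverse(rng.normal(size=5) + 10.0, accepted, threshold=0.5, rng=rng))  # True
```

Each candidate costs O(max_ref) instead of O(n), which keeps generation fast even with thousands of accepted samples.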
📄 License
This project is licensed under the MIT License - see the LICENSE file for details.
🙏 Acknowledgments
- Inspired by SMOTE and related oversampling techniques
- Built with scikit-learn, NumPy, SciPy, and pandas
- Thanks to the open-source community
💬 Contact
- Author: Atunrase Ayo
- Email: atunraseayomide@gmail.com
- GitHub: @Ayo-Cyber
- Repository: DistAwareAug
📝 Citation
If you use DistAwareAug in your research, please cite:
@software{distawareaug2025,
  author = {Atunrase, Ayo},
  title = {DistAwareAug: Distribution-Aware Data Augmentation for Imbalanced Learning},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/Ayo-Cyber/DistAwareAug}
}
🗺️ Roadmap
- Publish to PyPI
- Add more distribution methods (t-distribution, mixture models)
- GPU acceleration for large-scale augmentation
- Web-based interactive demo
- Integration with popular ML frameworks (PyTorch, TensorFlow)
- Automated hyperparameter tuning
- Real-world dataset benchmarks
⭐ Star this repo if you find it useful! ⭐
Report Bug • Request Feature • Contribute