GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
A PyTorch implementation of smart sampling for efficient deep learning training.
Overview
GRAFT uses gradient information and feature decomposition to select the most informative samples during training, reducing computation time while maintaining model performance.
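As a toy illustration of the gradient-importance idea (our own sketch, with hypothetical names — not GRAFT's actual scoring code): for cross-entropy, the gradient of the loss with respect to the logits is softmax(logits) minus the one-hot target, so its per-sample norm gives a cheap importance score.

import torch
import torch.nn.functional as F

def last_layer_gradient_scores(logits, targets):
    # d(cross-entropy)/d(logits) = softmax(logits) - one_hot(targets);
    # its per-sample norm is a cheap importance proxy (illustrative only).
    probs = logits.softmax(dim=1)
    one_hot = F.one_hot(targets, num_classes=logits.size(1)).float()
    return (probs - one_hot).norm(dim=1)

# Keep the highest-scoring half of a toy batch.
scores = last_layer_gradient_scores(torch.randn(8, 10), torch.randint(0, 10, (8,)))
keep = scores.topk(k=4).indices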
Features
- Smart sample selection using gradient-based importance scoring
- Multi-architecture support (ResNet, ResNeXt, EfficientNet, BERT)
- Dataset compatibility (CIFAR-10/100, TinyImageNet, Caltech256, medical datasets)
- Experiment tracking with Weights & Biases integration
- Carbon footprint tracking with eco2AI
- Efficient training with reduced computational overhead
Installation
From PyPI (Recommended)
pip install graft-pytorch
With optional dependencies
# For experiment tracking
pip install graft-pytorch[tracking]
# For development
pip install graft-pytorch[dev]
# Everything
pip install graft-pytorch[all]
From Source
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT
pip install -e .
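To confirm the install, the package metadata can be queried with the standard library (no GRAFT-specific API assumed):

from importlib.metadata import version

print(version("graft-pytorch"))  # prints the installed version, e.g. 1.0.0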
Quick Start
Command Line Interface
# Install and train with smart sampling
pip install graft-pytorch
# Basic training with GRAFT sampling on CIFAR-10
graft-train \
--numEpochs=200 \
--batch_size=128 \
--device="cuda" \
--optimizer="sgd" \
--lr=0.1 \
--numClasses=10 \
--dataset="cifar10" \
--model="resnet18" \
--fraction=0.5 \
--select_iter=25 \
--warm_start
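To compare sampling ratios, the same CLI can be driven from a small Python script. This sketch reuses only the flags shown above and sweeps --fraction:

import subprocess

# Sweep the sampling ratio using the graft-train CLI flags shown above.
for fraction in (0.1, 0.3, 0.5):
    subprocess.run(
        ["graft-train", "--numEpochs=200", "--batch_size=128",
         "--device=cuda", "--optimizer=sgd", "--lr=0.1",
         "--numClasses=10", "--dataset=cifar10", "--model=resnet18",
         f"--fraction={fraction}", "--select_iter=25", "--warm_start"],
        check=True,  # stop the sweep if a run fails
    )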
Python API
import torch
from graft import ModelTrainer, TrainingConfig
from graft.utils.loader import loader
# Load your dataset
trainloader, valloader, trainset, valset = loader(
dataset="cifar10",
trn_batch_size=128,
val_batch_size=128
)
# Configure training with GRAFT
config = TrainingConfig(
numEpochs=100,
batch_size=128,
device="cuda" if torch.cuda.is_available() else "cpu",
model_name="resnet18",
dataset_name="cifar10",
trainloader=trainloader,
valloader=valloader,
trainset=trainset,
optimizer_name="sgd",
lr=0.1,
fraction=0.5, # Use 50% of data per epoch
selection_iter=25, # Reselect samples every 25 epochs
warm_start=True # Train on full data initially
)
# Train with smart sampling
trainer = ModelTrainer(config, trainloader, valloader, trainset)
train_stats, val_stats = trainer.train()
print(f"Best validation accuracy: {val_stats['best_acc']:.2%}")
Advanced Usage
from graft import feature_sel, sample_selection
import torch.nn as nn
# Custom model and data selection
model = MyCustomModel()
data3 = feature_sel(dataloader, batch_size=128, device="cuda")
# Manual sample selection
selected_indices = sample_selection(
dataloader, data3, model, model.state_dict(),
batch_size=128, fraction=0.3, select_iter=10,
numEpochs=200, device="cuda", dataset="custom"
)
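Assuming sample_selection returns indices into the underlying dataset (not confirmed by the docs above), the result composes with standard PyTorch utilities:

from torch.utils.data import DataLoader, Subset

# Wrap the selected indices in a standard PyTorch Subset; assumes the
# indices address dataloader.dataset directly.
subset_loader = DataLoader(
    Subset(dataloader.dataset, selected_indices),
    batch_size=128,
    shuffle=True,
)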
Functionality Overview
Core Components
1. Smart Sample Selection
- sample_selection(): Selects the most informative samples using gradient-based importance
- feature_sel(): Performs feature decomposition for efficient sampling
- Reduces training time by 30-50% while maintaining model performance
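The "MaxVol" in the project name refers to selecting rows of a feature matrix that span a large volume. The greedy routine below is a toy stand-in for that step (our construction, not GRAFT's decomposition code):

import numpy as np

def greedy_volume_select(features, k):
    # Greedily pick k rows of an (n x d) feature matrix that approximately
    # maximize the volume they span: repeatedly take the row with the largest
    # residual norm, then project that direction out. Toy MaxVol-style
    # stand-in, not GRAFT's implementation.
    residual = features.astype(float)
    selected = []
    for _ in range(k):
        norms = np.linalg.norm(residual, axis=1)
        i = int(np.argmax(norms))
        selected.append(i)
        v = residual[i] / (norms[i] + 1e-12)
        residual -= np.outer(residual @ v, v)  # remove the chosen direction
    return selected

indices = greedy_volume_select(np.random.randn(100, 16), k=10)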
2. Supported Models
- Vision Models: ResNet, ResNeXt, EfficientNet, MobileNet, FashionCNN
- Language Models: BERT for sequence classification
- Custom Models: Easy integration with any PyTorch model
3. Dataset Support
- Computer Vision: CIFAR-10/100, TinyImageNet, Caltech256
- Medical Imaging: Integration with MedMNIST datasets
- Custom Datasets: Support for any PyTorch DataLoader
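For the "any PyTorch DataLoader" case, a standard torchvision pipeline slots into the TrainingConfig fields used in the Quick Start (a sketch; transforms kept minimal):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Build plain PyTorch loaders, then pass them to TrainingConfig as in Quick Start.
transform = transforms.ToTensor()
trainset = datasets.CIFAR10("./data", train=True, download=True, transform=transform)
valset = datasets.CIFAR10("./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
valloader = DataLoader(valset, batch_size=128)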
4. Training Features
- Dynamic Sampling: Adaptive sample selection during training
- Warm Starting: Begin with full dataset, then switch to sampling
- Experiment Tracking: Built-in WandB integration
- Carbon Tracking: Monitor environmental impact with eco2AI
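The carbon tracking can also be reproduced standalone by wrapping any training call with eco2AI's tracker (a sketch based on eco2AI's public Tracker API; check its docs for current parameters):

import eco2ai

tracker = eco2ai.Tracker(project_name="GRAFT", experiment_description="cifar10 run")
tracker.start()
# ... run training, e.g. trainer.train() ...
tracker.stop()  # writes emission estimates to a local CSV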
Configuration Parameters
| Parameter | Description | Default | Options |
|---|---|---|---|
| numEpochs | Training epochs | 200 | Any integer |
| batch_size | Batch size | 128 | 32, 64, 128, 256+ |
| device | Computing device | "cuda" | "cpu", "cuda" |
| model | Model architecture | "resnet18" | "resnet18/50", "resnext", "efficientnet" |
| fraction | Data sampling ratio | 0.5 | 0.1 - 1.0 |
| select_iter | Reselection frequency (epochs) | 25 | Any integer |
| optimizer | Optimization algorithm | "sgd" | "sgd", "adam" |
| lr | Learning rate | 0.1 | 0.001 - 0.1 |
| warm_start | Use full data initially | False | True/False |
| decomp | Decomposition backend | "numpy" | "numpy", "torch" |
Performance Benefits
- Speed: 30-50% reduction in training time
- Memory: Reduced memory usage through smart sampling
- Accuracy: Maintains or improves model performance
- Efficiency: Lower carbon footprint and energy consumption
Package Structure
graft-pytorch/
├── graft/
│ ├── __init__.py # Main package exports
│ ├── trainer.py # Training orchestration
│ ├── genindices.py # Sample selection algorithms
│ ├── decompositions.py # Feature decomposition
│ ├── models/ # Supported architectures
│ │ ├── resnet.py # ResNet implementations
│ │ ├── efficientnet.py # EfficientNet models
│ │ └── BERT_model.py # BERT for classification
│ └── utils/ # Utility functions
│ ├── loader.py # Dataset loaders
│ └── model_mapper.py # Model selection
├── tests/ # Comprehensive test suite
├── examples/ # Usage examples
└── OIDC_SETUP.md # Deployment configuration
Contributing
We welcome contributions! Please see our contribution guidelines for details.
Development Setup
# Clone the repository
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT
# Install in development mode
pip install -e ".[dev]"
# Run tests
pytest tests/ -v
# Run linting
flake8 graft/ tests/
License
This project is licensed under the MIT License - see the LICENSE file for details.
Citation
If you use GRAFT in your research, please cite our paper:
@misc{jha2025graftgradientawarefastmaxvol,
title = {GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling},
author = {Ashish Jha and Anh Huy Phan and Razan Dibo and Valentin Leplat},
year = {2025},
eprint = {2508.13653},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2508.13653}
}
Acknowledgments
- Built using PyTorch
- Inspired by MaxVol techniques for data sampling
- Special thanks to the open-source community
PyPI Package: graft-pytorch
Paper: arXiv:2508.13653
Issues: GitHub Issues
Contact: Ashish Jha