
GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling

PyPI version · Python 3.8+ · License: MIT

A PyTorch implementation of smart sampling for efficient deep learning training.

Overview

GRAFT uses gradient information and feature decomposition to select the most informative samples during training, reducing computation time while maintaining model performance.
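To give a flavor of the MaxVol side of the method, here is a minimal sketch using a pivoted-QR heuristic, a common stand-in for MaxVol-style row selection. The function name and toy data are illustrative only, not the package's actual implementation:

import numpy as np
from scipy.linalg import qr

def select_rows_maxvol_like(features: np.ndarray, k: int) -> np.ndarray:
    """Pick k rows whose submatrix has approximately maximal volume
    (greedy pivoted-QR heuristic, not GRAFT's exact algorithm)."""
    # Column-pivoted QR on features.T orders the rows of `features`
    # by how much new (orthogonal) volume each one contributes.
    _, _, pivots = qr(features.T, pivoting=True, mode="economic")
    return pivots[:k]

# Toy usage: 1000 samples with 32-dimensional features.
rng = np.random.default_rng(0)
toy_features = rng.normal(size=(1000, 32))
subset = select_rows_maxvol_like(toy_features, k=100)

GRAFT combines a selection of this kind with gradient information so the retained subset stays informative as training progresses.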

Features

  • Smart sample selection using gradient-based importance scoring
  • Multi-architecture support (ResNet, ResNeXt, EfficientNet, MobileNet, BERT)
  • Dataset compatibility (CIFAR-10/100, TinyImageNet, Caltech256, MedMNIST medical datasets)
  • Experiment tracking with Weights & Biases integration
  • Carbon footprint tracking with eco2AI
  • Efficient training with reduced computational overhead

Installation

From PyPI (Recommended)

pip install graft-pytorch

With optional dependencies

# For experiment tracking
pip install graft-pytorch[tracking]

# For development
pip install graft-pytorch[dev]

# Everything
pip install graft-pytorch[all]

From Source

git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT
pip install -e .

Quick Start

Command Line Interface

# Install and train with smart sampling
pip install graft-pytorch

# Basic training with GRAFT sampling on CIFAR-10
graft-train \
    --numEpochs=200 \
    --batch_size=128 \
    --device="cuda" \
    --optimizer="sgd" \
    --lr=0.1 \
    --numClasses=10 \
    --dataset="cifar10" \
    --model="resnet18" \
    --fraction=0.5 \
    --select_iter=25 \
    --warm_start

Python API

import torch
from graft import ModelTrainer, TrainingConfig
from graft.utils.loader import loader

# Load your dataset
trainloader, valloader, trainset, valset = loader(
    dataset="cifar10",
    trn_batch_size=128,
    val_batch_size=128
)

# Configure training with GRAFT
config = TrainingConfig(
    numEpochs=100,
    batch_size=128,
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_name="resnet18",
    dataset_name="cifar10",
    trainloader=trainloader,
    valloader=valloader,
    trainset=trainset,
    optimizer_name="sgd",
    lr=0.1,
    fraction=0.5,         # Use 50% of data per epoch
    selection_iter=25,    # Reselect samples every 25 epochs
    warm_start=True       # Train on full data initially
)

# Train with smart sampling
trainer = ModelTrainer(config, trainloader, valloader, trainset)
train_stats, val_stats = trainer.train()

print(f"Best validation accuracy: {val_stats['best_acc']:.2%}")

Advanced Usage

from graft import feature_sel, sample_selection

# Custom model and data selection (MyCustomModel and `dataloader`
# are placeholders for your own model and DataLoader)
model = MyCustomModel()

# Decompose features once for efficient sampling
features = feature_sel(dataloader, batch_size=128, device="cuda")

# Manually select a 30% subset, refreshed every 10 epochs
selected_indices = sample_selection(
    dataloader, features, model, model.state_dict(),
    batch_size=128, fraction=0.3, select_iter=10,
    numEpochs=200, device="cuda", dataset="custom"
)
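The returned indices plug into a standard PyTorch sampler, so the rest of the training loop stays unchanged. A minimal sketch, assuming `trainset` is the dataset behind `dataloader` above:

from torch.utils.data import DataLoader, SubsetRandomSampler

# Visit only the GRAFT-selected subset each epoch.
sampler = SubsetRandomSampler(selected_indices)
subset_loader = DataLoader(trainset, batch_size=128, sampler=sampler)

for inputs, targets in subset_loader:
    ...  # usual forward/backward/step on the reduced subset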

Functionality Overview

Core Components

1. Smart Sample Selection

  • sample_selection(): Selects the most informative samples using gradient-based importance scoring (see the sketch after this list)
  • feature_sel(): Performs feature decomposition for efficient sampling
  • Reduces training time by 30-50% while maintaining model performance
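For intuition, a generic gradient-based importance score ranks samples by the size of their per-sample loss gradients. The sketch below uses last-layer gradient norms computed in closed form; this is a common, cheap proxy, not GRAFT's exact scoring:

import torch
import torch.nn.functional as F

def gradient_importance_scores(model, dataloader, device="cpu"):
    """Score each sample by the norm of its loss gradient w.r.t. the
    logits (a cheap proxy for the full per-sample gradient norm)."""
    model.eval()
    all_scores = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            # For cross-entropy, d(loss)/d(logits) = softmax(logits) - one_hot(y),
            # so each sample can be scored without a backward pass.
            grad = F.softmax(logits, dim=1)
            grad[torch.arange(len(targets), device=device), targets] -= 1.0
            all_scores.append(grad.norm(dim=1))
    return torch.cat(all_scores)

Keeping the top fraction of samples by such a score, and refreshing the ranking every select_iter epochs, gives the dynamic-sampling behavior described here.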

2. Supported Models

  • Vision Models: ResNet, ResNeXt, EfficientNet, MobileNet, FashionCNN
  • Language Models: BERT for sequence classification
  • Custom Models: Easy integration with any PyTorch model

3. Dataset Support

  • Computer Vision: CIFAR-10/100, TinyImageNet, Caltech256
  • Medical Imaging: Integration with MedMNIST datasets
  • Custom Datasets: Support for any map-style PyTorch dataset/DataLoader (see the sketch below)
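For custom data, any map-style dataset is enough. A hypothetical sketch (MyImageDataset and the random tensors are illustrative only):

import torch
from torch.utils.data import Dataset, DataLoader

class MyImageDataset(Dataset):
    """Minimal map-style dataset: indexable samples are all GRAFT needs."""
    def __init__(self, images, labels):
        self.images = images      # e.g. a float tensor [N, C, H, W]
        self.labels = labels      # e.g. a long tensor [N]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return self.images[i], self.labels[i]

trainset = MyImageDataset(torch.randn(512, 3, 32, 32),
                          torch.randint(0, 10, (512,)))
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)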

4. Training Features

  • Dynamic Sampling: Adaptive sample selection during training
  • Warm Starting: Begin with full dataset, then switch to sampling
  • Experiment Tracking: Built-in WandB integration
  • Carbon Tracking: Monitor environmental impact with eco2AI

Configuration Parameters

Parameter   | Description               | Default    | Options
------------|---------------------------|------------|--------------------------------------------------
numEpochs   | Training epochs           | 200        | any integer
batch_size  | Batch size                | 128        | 32, 64, 128, 256+
device      | Computing device          | "cuda"     | "cpu", "cuda"
model       | Model architecture        | "resnet18" | "resnet18", "resnet50", "resnext", "efficientnet"
fraction    | Data sampling ratio       | 0.5        | 0.1 - 1.0
select_iter | Reselection frequency     | 25         | any integer
optimizer   | Optimization algorithm    | "sgd"      | "sgd", "adam"
lr          | Learning rate             | 0.1        | 0.001 - 0.1
warm_start  | Use full data initially   | False      | True / False
decomp      | Decomposition backend     | "numpy"    | "numpy", "torch"

Performance Benefits

  • Speed: 30-50% faster training time
  • Memory: Reduced memory usage through smart sampling
  • Accuracy: Maintains or improves model performance
  • Efficiency: Lower carbon footprint and energy consumption

Package Structure

graft-pytorch/
├── graft/
│   ├── __init__.py          # Main package exports
│   ├── trainer.py           # Training orchestration
│   ├── genindices.py        # Sample selection algorithms
│   ├── decompositions.py    # Feature decomposition
│   ├── models/              # Supported architectures
│   │   ├── resnet.py        # ResNet implementations  
│   │   ├── efficientnet.py  # EfficientNet models
│   │   └── BERT_model.py    # BERT for classification
│   └── utils/               # Utility functions
│       ├── loader.py        # Dataset loaders
│       └── model_mapper.py  # Model selection
├── tests/                   # Comprehensive test suite
├── examples/                # Usage examples
└── OIDC_SETUP.md           # Deployment configuration

Contributing

We welcome contributions! Please see our contribution guidelines for details.

Development Setup

# Clone the repository
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT

# Install in development mode
pip install -e ".[dev]"

# Run tests
pytest tests/ -v

# Run linting
flake8 graft/ tests/

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use GRAFT in your research, please cite our paper:

@misc{jha2025graftgradientawarefastmaxvol,
  title         = {GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling},
  author        = {Ashish Jha and Anh Huy Phan and Razan Dibo and Valentin Leplat},
  year          = {2025},
  eprint        = {2508.13653},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2508.13653}
}

Acknowledgments

  • Built using PyTorch
  • Inspired by MaxVol techniques for data sampling
  • Special thanks to the open-source community

PyPI Package: graft-pytorch
Paper: arXiv:2508.13653
Issues: GitHub Issues
Contact: Ashish Jha
