DataCull is a modular, lightweight data pruning library containing many dataset pruning (coreset selection) algorithms, including the official implementation of the paper "RCAP: Robust, Class-Aware, Probabilistic Dynamic Dataset Pruning".
Project description
DataCull
DataCull is a lightweight, flexible PyTorch framework for data pruning during model training. It provides modular, composable components for implementing and experimenting with data pruning algorithms. Because DataCull decouples importance scoring from sampling logic, it allows, for the first time, mixing and matching the importance criteria and sampling strategies of different pruning methods.
DataCull comes with the official implementation of the RCAP (Robust Class-Aware Probabilistic) dynamic data pruning algorithm.
It also includes unofficial implementations of the following data pruning algorithms:
- CCS (Coverage-centric Coreset Selection for High Pruning Rates)
- TDDS (Spanning Training Progress: Temporal Dual-Depth Scoring (TDDS) for Enhanced Dataset Pruning)
- MetriQ (Robust Data Pruning: Uncovering and Overcoming Implicit Bias)
- RS2 (Repeated Random Sampling for Minimizing the Time-to-Accuracy of Learning)
Features
- Modular Design: Clean abstractions for datasets, dataloaders, importance scoring, and logging. Decouples importance scoring and sampling logic, allowing you to mix and match the importance criteria and sampling strategies of different pruning methods.
- Multiple Pruning Algorithms: Built-in implementations of state-of-the-art data pruning methods.
- Dynamic and Static Pruning: Support for both per-epoch (or per-n-epochs) re-sampling and one-time pruning.
- Per-Sample Tracking: Automatically track metrics and importance scores for every sample across training epochs.
- PyTorch and PyTorch Lightning Compatible: Drop-in replacements for PyTorch Dataset and DataLoader (no modification to existing workflows).
- Flexible Importance Scoring: Extensible framework for custom importance computation methods.
- Flexible Pruning: Extensible framework for custom pruning logic.
How to install?
```shell
pip install datacull
```
Quick Start
Basic Usage
Here's a minimal example using DataCull with a standard PyTorch dataset for dynamic data pruning:
```python
import torch
from torch.utils.data import DataLoader
from datacull import DCDataset, DCDataLoader, DCLogger, DCImportance

# 1. Wrap your existing dataset
dataset = DCDataset(your_pytorch_dataset)

# 2. Create a logger to track per-sample metrics
logger = DCLogger(trajectory_dir="./trajectory_directory/", save_every_k_epoch=1)

# 3. Create a dataloader that inherits DCDataLoader and implements the compute_subset function
dataloader = YourPruningDataLoader(
    dataset=dataset,
    pruning_rate=0.2,  # Remove 20% of samples
    batch_size=32
)

# 4. During training, log metrics and resample
for epoch in range(num_epochs):
    for batch in dataloader:
        x, y, idx = batch  # idx contains sample indices
        # Your training code here
        preds = model(x)
        # Log per-sample metrics (e.g., preds)
        logger.log_metric(epoch, idx, preds)

    # Compute importance scores
    # YourImportanceMethod must inherit DCImportance and implement compute_importance
    importance_computer = YourImportanceMethod(...)
    importance_scores = importance_computer.compute_importance()

    # Resample the dataset based on importance
    dataloader.resample(importance_scores)
```
Core Classes
DCDataset
A wrapper around PyTorch datasets that appends the sample index to each batch:
```python
from datacull import DCDataset

wrapped_dataset = DCDataset(your_dataset)
# Each item now returns: (*original_outputs, sample_index)
```
DCDataLoader
A customizable DataLoader supporting both dynamic and static sample pruning with importance scores.
`__init__(self, dataset: DCDataset, pruning_rate: float, static: bool, **kwargs)`
- dataset (DCDataset): any PyTorch dataset wrapped with the DCDataset class.
- pruning_rate (float, in (0, 1)): the fraction of samples to remove.
- static (bool): if True, resample only once (static mode); if False, resample a new subset during training (dynamic mode).
```python
# This function needs to be implemented by the user when creating their own pruning algorithm.
# It holds the pruning logic and returns a list of indices (a subset)
# that determines which samples to keep.
compute_subset(self, sample_importance: list)
```
- sample_importance (list): an importance score corresponding to each sample in the dataset.
```python
# This function calls compute_subset.
# It also determines whether to sample once (static) or repeatedly (dynamic).
resample(self, sample_importance: list)
```
- sample_importance (list): an importance score corresponding to each sample in the dataset.
Example Usage
```python
# Implement compute_subset() to define your pruning strategy
class MyPruner(DCDataLoader):
    def compute_subset(self, sample_importance):
        # Write pruning logic, with or without sample_importance
        # Return indices of samples to keep
        return indices_to_keep

# Create the data loader
my_data_loader = MyPruner(DCDataset(my_dataset), pruning_rate=0.3, batch_size=32)

# Select a new subset
# Here, we assume that your pruning logic does not require importance scores
my_data_loader.resample(None)
```
DCLogger
Efficiently logs per-sample metrics across training epochs.
`__init__(self, trajectory_dir: str, save_every_k_epoch: int = 1)`
- trajectory_dir (str): the directory where a model's training metrics will be stored.
- save_every_k_epoch (int, default 1): save metrics every k epochs.
```python
# Call this function to save a given metric during training
log_metric(self, epoch: int, sample_idx: torch.Tensor, metric: torch.Tensor)
```
- epoch (int): the current epoch number.
- sample_idx (torch.Tensor): a batch of indices (provided automatically by the DCDataset class).
- metric (torch.Tensor): a batch of metrics to log, such as predictions or losses.
Example Usage
```python
logger = DCLogger(trajectory_dir="./trajectories/", save_every_k_epoch=2)

# During training
logger.log_metric(epoch, sample_indices, loss_values)
# Creates: ./trajectories/epoch{E}.jsonl
```
DCImportance
Base class for computing importance scores from logged trajectories.
`__init__(self, dataset: DCDataset, window_size: int, logger_object: DCLogger, flush: bool = False)`
- dataset (DCDataset): a PyTorch dataset wrapped by the DCDataset class.
- window_size (int): the number of consecutive epochs to extract.
- logger_object (DCLogger): a DCLogger object that determines the logging directory and which epochs have been saved.
- flush (bool, default False): whether to delete from disk the metrics (trajectory segment) that have currently been read into memory (useful for dynamic methods).
```python
# Returns the segment start_epoch:start_epoch + window_size of the trajectory
extract_trajectory_segment(self, start_epoch: int)
```
- start_epoch (int): the point in the trajectory from which to extract the current segment.
```python
# This function needs to be implemented by the user.
# Returns a list containing the importance score for each sample.
compute_importance(self)
```
Example Usage
```python
# Create your sample importance class
class YourImportanceMethod(DCImportance):
    def compute_importance(self):
        for epoch in range(max_epochs - window_size + 1):
            segment = self.extract_trajectory_segment(epoch)
            # Write your sample importance logic here
        return sample_importance

# Create your sample importance object
importance_object = YourImportanceMethod(dataset=dataset, window_size=5, logger_object=logger)
importance_scores = importance_object.compute_importance()
```
Available Methods
AUM (Area Under the Margin)
Identifies easy-to-learn samples by computing the margin between true class logits and max other class logits.
Class: AUMImportance from datacull.methods.CCS
Example Usage
```python
from datacull.methods.CCS import AUMImportance

importance = AUMImportance(dataset=dataset, trajectory_length=num_epochs, logger_object=logger)
scores = importance.compute_importance()
```
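For intuition, the margin that AUM accumulates over epochs can be sketched outside the library. The `margin` helper below is purely illustrative and not part of the datacull API:

```python
import torch

def margin(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Per-sample margin: true-class logit minus the largest other-class logit."""
    true_logit = logits.gather(1, targets.unsqueeze(1)).squeeze(1)
    # Mask out the true class before taking the max over the remaining classes
    masked = logits.clone()
    masked.scatter_(1, targets.unsqueeze(1), float("-inf"))
    return true_logit - masked.max(dim=1).values

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 1])
margins = margin(logits, targets)  # sample 0: 2.0 - 0.5 = 1.5; sample 1: 0.2 - 3.0 = -2.8
```

Averaging this margin across training epochs yields a high score for consistently easy samples and a low (or negative) score for hard or mislabeled ones.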
CCS (Coverage-centric Coreset Selection)
Uses AUM scores with stratified sampling to maintain dataset diversity at high pruning rates.
Class: CCSDataLoader from datacull.methods.CCS
Example Usage

```python
train_dataloader = CCSDataLoader(dataset=train_set, pruning_rate=0.3, beta=0.1, num_strata=50, descending=False, batch_size=128, num_workers=1)
train_dataloader.resample(scores)
```
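The stratified idea behind CCS can be sketched in plain NumPy: bucket samples into equal-width score strata, then spread the kept budget evenly over the non-empty strata. This is a simplified sketch, not the CCSDataLoader internals (which also handle the beta cutoff and sort order):

```python
import numpy as np

def stratified_subset(scores, keep_fraction, num_strata, seed=0):
    """Bucket samples into equal-width score strata, then spread the
    kept budget evenly over the non-empty strata (simplified sketch)."""
    rng = np.random.default_rng(seed)
    scores = np.asarray(scores, dtype=float)
    # Assign each sample to one of num_strata equal-width score bins
    edges = np.linspace(scores.min(), scores.max(), num_strata + 1)
    bins = np.digitize(scores, edges[1:-1])
    budget = int(round(len(scores) * keep_fraction))
    strata = [np.flatnonzero(bins == b) for b in range(num_strata)]
    strata = [s for s in strata if len(s) > 0]
    kept = []
    for i, stratum in enumerate(strata):
        # Even split of the remaining budget over the remaining strata
        take = min(len(stratum), max(0, (budget - len(kept)) // (len(strata) - i)))
        kept.extend(rng.choice(stratum, size=take, replace=False).tolist())
    return sorted(kept)

subset = stratified_subset(np.linspace(0, 1, 100), keep_fraction=0.5, num_strata=5)
```

Because every stratum contributes to the subset, even high pruning rates retain both easy and hard samples, which is the coverage property CCS is built around.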
TDDS (Temporal Dual-Depth Scoring)
Leverages temporal stability of predictions across epochs.
Classes: TDDSImportance, TDDSDataLoader from datacull.methods.TDDS
Example Usage

```python
from datacull.methods.TDDS import TDDSImportance, TDDSDataLoader

importance_object = TDDSImportance(dataset=dataset, trajectory_length=num_epochs, window_size=5, decay=0.9, logger_object=logger)
scores = importance_object.compute_importance()

train_dataloader = TDDSDataLoader(dataset=train_set, pruning_rate=0.3, batch_size=128, num_workers=1)
train_dataloader.resample(scores)
```
MetriQ
Class-balanced pruning, inversely proportional to per-class validation accuracy.
Class: MetriQDataLoader from datacull.methods.MetriQ
Example Usage

```python
import numpy as np
from datacull.methods.MetriQ import MetriQDataLoader

# Requires validation accuracy per class
class_wise_acc = np.array([0.95, 0.80, 0.88])
train_dataloader = MetriQDataLoader(dataset=dataset, pruning_rate=0.3, class_wise_acc=class_wise_acc, batch_size=64, num_workers=1)
train_dataloader.resample(None)
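The class-balancing rule can be sketched as follows: the kept budget is spread across classes in proportion to (1 - accuracy), so classes the model handles poorly retain more of their samples. The `class_keep_budgets` helper below is an illustrative sketch, not the MetriQDataLoader internals:

```python
import numpy as np

def class_keep_budgets(class_wise_acc, class_counts, keep_fraction):
    """Spread the kept budget across classes in proportion to
    (1 - accuracy), so weaker classes keep more samples."""
    acc = np.asarray(class_wise_acc, dtype=float)
    counts = np.asarray(class_counts)
    weights = (1.0 - acc) / (1.0 - acc).sum()
    budget = counts.sum() * keep_fraction
    # Never keep more samples than a class actually has
    return np.minimum(np.round(weights * budget).astype(int), counts)

budgets = class_keep_budgets([0.95, 0.80, 0.88], [100, 100, 100], 0.7)
# The weakest class (0.80 accuracy) receives the largest share of the budget
```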
RS2 (Repeated Random Sampling)
Fast random sampling with optional stratification for class balance.
Class: RS2DataLoader from datacull.methods.RS2
Example Usage

```python
from datacull.methods.RS2 import RS2DataLoader

train_dataloader = RS2DataLoader(dataset=dataset, pruning_rate=0.3, sampling_with_replacement=False, stratify=False, batch_size=64, num_workers=1)
train_dataloader.resample(None)
```
RCAP (Robust, Class-Aware, Probabilistic Dynamic Pruning)
Dynamic class-aware probabilistic sampling using loss-based importance scores.
Classes: RCAPImportance, RCAPDataLoader from datacull.methods.RCAP
Example Usage

```python
from datacull.methods.RCAP import RCAPImportance, RCAPDataLoader

importance_object = RCAPImportance(dataset=dataset, logger_object=logger, beta=2.0, clipping_threshold=None)
train_dataloader = RCAPDataLoader(dataset=dataset, pruning_rate=0.3, batch_size=64, num_workers=1)
train_dataloader.resample(importance_object.compute_importance())
```
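The probabilistic part can be illustrated with a generic loss-proportional sampler. This is a simplified sketch only; RCAP's actual scoring adds class-awareness and robustness (see the beta and clipping_threshold arguments above, and the paper for details):

```python
import numpy as np

def probabilistic_subset(losses, keep_fraction, seed=0):
    """Sample a subset without replacement, with keep probability
    proportional to per-sample loss (illustrative, not the RCAP code)."""
    rng = np.random.default_rng(seed)
    losses = np.asarray(losses, dtype=float)
    probs = losses / losses.sum()
    k = int(round(len(losses) * keep_fraction))
    # High-loss (harder) samples are more likely to be kept
    return rng.choice(len(losses), size=k, replace=False, p=probs)

kept = probabilistic_subset(np.arange(1, 101, dtype=float), keep_fraction=0.7)
```

Sampling (rather than deterministically thresholding) means every sample retains some chance of being selected each round, which is what makes dynamic re-sampling across epochs meaningful.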
Mixing importance and sampling techniques from different methods
```python
from datacull.methods.TDDS import TDDSImportance
from datacull.methods.CCS import CCSDataLoader

importance_object = TDDSImportance(dataset=dataset, trajectory_length=num_epochs, window_size=5, decay=0.9, logger_object=logger)
scores = importance_object.compute_importance()

train_dataloader = CCSDataLoader(dataset=train_set, pruning_rate=0.3, beta=0.1, num_strata=50, descending=False, batch_size=128, num_workers=1)
train_dataloader.resample(scores)
```
Custom Pruning Strategy Example
```python
import numpy as np
from datacull import DCDataLoader

class RandomPruner(DCDataLoader):
    """Simple random pruning baseline"""
    def compute_subset(self, sample_importance):
        # Randomly select samples to keep
        indices = np.arange(self.total_num_samples)
        np.random.shuffle(indices)
        return indices[:self.required_num_samples].tolist()

# Use it
pruner = RandomPruner(dataset, pruning_rate=0.3, batch_size=64)
pruner.resample(None)
```
Future Ideas
- PyTorch-specific examples
- Implement more data pruning algorithms
Citation
If you use DataCull in your research, please cite it as:
```bibtex
@inproceedings{hassanrcap,
  title={RCAP: Robust, Class-Aware, Probabilistic Dynamic Dataset Pruning},
  author={Hassan, Atif and Khare, Swanand and Paik, Jiaul H},
  booktitle={The 41st Conference on Uncertainty in Artificial Intelligence}
}
```
Happy pruning! 🌱
File details
Details for the file datacull-1.0.2.tar.gz.
File metadata
- Download URL: datacull-1.0.2.tar.gz
- Upload date:
- Size: 19.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `e94ac078d28281372e14d1625bb92d949d69c4c492d8534c71591d5fc95ddca9` |
| MD5 | `a001bb2038ba32f9ec138f4a316c841c` |
| BLAKE2b-256 | `a1e532ee4e80d4768abdb5e9cebdf7d0162f9043c0a7a4ef3cb0e1226c85ecdb` |
File details
Details for the file datacull-1.0.2-py3-none-any.whl.
File metadata
- Download URL: datacull-1.0.2-py3-none-any.whl
- Upload date:
- Size: 19.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `95860316b7ff460e8c05b670d9c4fb51e0ad63113d974ec7b72a2ad417ec508c` |
| MD5 | `f5751ebc780988def52709809ca80a48` |
| BLAKE2b-256 | `c73f7315c029a2fe030ea9beb89324b23c031be4ca6c2fbba94005310c3d2a11` |