A lightweight, extensible Python library for data pruning with Hugging Face datasets and transformers

🌿 dPrune: A Framework for Data Pruning

dPrune is a lightweight, extensible Python library designed to make data selection and pruning simple and accessible for NLP and speech tasks, with first-class support for Hugging Face datasets and transformers.

Data pruning is the process of selecting a smaller, more informative, higher-quality subset of a large training dataset. Removing noisy or redundant examples can speed up training, lower computational costs, and even improve model performance. dPrune provides a modular framework for experimenting with different pruning strategies.


⭐ Key Features

  • Hugging Face Integration: Works seamlessly with the Hugging Face datasets and transformers libraries.
  • Modular Design: Separates the scoring logic from the pruning criteria.
  • Extensible: Easily create your own custom scoring functions and pruning methods.
  • Supervised & Unsupervised Scoring Methods: Includes a variety of common pruning techniques.
    • Supervised: Score data based on model outputs (e.g., cross-entropy loss, forgetting scores).
    • Unsupervised: Score data based on intrinsic properties (e.g., clustering embeddings, perplexity scores).
  • Multiple Pruning Strategies: Supports top/bottom-k pruning, stratified sampling, and random pruning.

📦 Installation

You can install dPrune via pip:

pip install dprune

Alternatively, you can use uv:

uv pip install dprune

To install the library with all testing dependencies, run:

pip install "dprune[test]"

🚀 Quick Start

Here's a simple example of how to prune a dataset using unsupervised KMeans clustering. This approach keeps the most representative examples (closest to cluster centroids) without requiring labels or fine-tuning.

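from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from dprune import PruningPipeline, KMeansCentroidDistanceScorer, BottomKPruner

data = {'text': ['A great movie!', 'Waste of time.', 'Amazing.', 'So predictable.']}
raw_dataset = Dataset.from_dict(data)

# Load the model and tokenizer used to embed the text.
# distilbert-base-uncased is only an example choice; swap in your own model.
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Score each example by its distance to its k-means cluster centroid.
scorer = KMeansCentroidDistanceScorer(
    model=model,
    tokenizer=tokenizer,
    text_column='text',
    num_clusters=2
)
# Keep the 50% of examples closest to their centroids (lowest scores).
pruner = BottomKPruner(k=0.5)

pipeline = PruningPipeline(scorer=scorer, pruner=pruner)
pruned_dataset = pipeline.run(raw_dataset)

print(f"Original dataset size: {len(raw_dataset)}")
print(f"Pruned dataset size: {len(pruned_dataset)}")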

💡 Core Concepts

dPrune is built around three core components:

Scorer

A Scorer takes a Dataset and adds a new score column to it. The score is a numerical value that represents some property of the example (e.g., how hard it is for the model to classify).

Pruner

A Pruner takes a scored Dataset and selects a subset of it based on the score column.

PruningPipeline

The PruningPipeline is a convenience wrapper that chains a Scorer and a Pruner together into a single, easy-to-use workflow.
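To make the Scorer, Pruner, and PruningPipeline contract concrete, here is a minimal, dependency-free sketch of the same three roles. It is not dPrune's code: plain lists of dicts stand in for a Hugging Face Dataset, and LengthScorer, BottomFractionPruner, and ToyPipeline are hypothetical names chosen for illustration.

```python
from typing import Any, Dict, List

Example = Dict[str, Any]


class LengthScorer:
    """Hypothetical scorer: adds a 'score' field equal to the text length."""

    def score(self, rows: List[Example]) -> List[Example]:
        return [{**row, "score": len(str(row["text"]))} for row in rows]


class BottomFractionPruner:
    """Hypothetical pruner: keeps the fraction k of rows with the lowest scores."""

    def __init__(self, k: float):
        self.k = k

    def prune(self, rows: List[Example]) -> List[Example]:
        n_keep = int(len(rows) * self.k)
        # Rank indices by ascending score (stable, so ties keep input order).
        ranked = sorted(range(len(rows)), key=lambda i: rows[i]["score"])
        keep = sorted(ranked[:n_keep])  # restore the original row order
        return [rows[i] for i in keep]


class ToyPipeline:
    """Hypothetical pipeline: score first, then prune on the score field."""

    def __init__(self, scorer, pruner):
        self.scorer = scorer
        self.pruner = pruner

    def run(self, rows: List[Example]) -> List[Example]:
        return self.pruner.prune(self.scorer.score(rows))


rows = [{"text": t} for t in ["A great movie!", "Waste of time.", "Amazing.", "So predictable."]]
pruned = ToyPipeline(LengthScorer(), BottomFractionPruner(k=0.5)).run(rows)
print([r["text"] for r in pruned])
```

Because the pruner only reads the score field, any scorer can be paired with any pruner, which is the point of keeping the two roles separate.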

🛠️ Available Components

Scorers

  • KMeansCentroidDistanceScorer: (Unsupervised) Embeds the data, performs k-means clustering, and scores each example by its distance to its cluster centroid.
  • PerplexityScorer: (Unsupervised) Calculates a perplexity score for each example using a KenLM n-gram language model.
  • CrossEntropyScorer: (Supervised) Scores examples based on the cross-entropy loss from a given model.
  • ForgettingScorer: (Supervised) Works with a ForgettingCallback to score examples based on how many times they are "forgotten" during training.
  • ...many more coming soon!
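The idea behind KMeansCentroidDistanceScorer can be illustrated with a small, dependency-free sketch. It assumes the embeddings have already been computed (the toy 2-D vectors below stand in for model embeddings), uses plain Lloyd's k-means with a naive "first k points" initialization, and is not dPrune's implementation; the function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def _dist(a: Vector, b: Vector) -> float:
    """Euclidean distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def kmeans(points: List[Vector], k: int, iters: int = 20) -> Tuple[List[List[float]], List[int]]:
    """Lloyd's algorithm with a deterministic first-k initialization (hypothetical helper)."""
    centroids = [list(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: each point joins its nearest centroid.
        labels = [min(range(k), key=lambda c: _dist(p, centroids[c])) for p in points]
        # Update step: each centroid moves to the mean of its members.
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return centroids, labels


def centroid_distance_scores(points: List[Vector], k: int) -> List[float]:
    """Score = distance from each point to the centroid of its own cluster."""
    centroids, labels = kmeans(points, k)
    return [_dist(p, centroids[lab]) for p, lab in zip(points, labels)]


# Toy "embeddings": two tight groups plus one point lying between them.
embeddings = [(0.0, 0.0), (5.0, 5.0), (0.2, 0.0), (5.0, 5.4), (3.0, 3.0)]
scores = centroid_distance_scores(embeddings, k=2)
print([round(s, 3) for s in scores])
```

Low scores mark prototypical examples near a centroid and high scores mark outliers, which is why the Quick Start pairs this scorer with BottomKPruner to keep the most representative examples.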

Pruners

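  • TopKPruner: Selects the k examples with the highest scores. As in the examples in this README, k can also be given as a fraction of the dataset (e.g., k=0.8 keeps 80%).
  • BottomKPruner: Selects the k examples with the lowest scores (or the lowest-scoring fraction, e.g., k=0.5 keeps 50%).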
  • StratifiedPruner: Divides the data into strata based on score quantiles and samples proportionally from each.
  • RandomPruner: Randomly selects k examples, ignoring scores. Useful for establishing a baseline.
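The four selection rules can be sketched over a plain list of scores, returning the indices to keep. This is an illustration, not dPrune's code: the function names are hypothetical, and the rule that a float k means "fraction of the dataset" while an int k means "count" is an assumption of this sketch.

```python
import random
from typing import List, Sequence, Union


def _num_to_keep(n: int, k: Union[int, float]) -> int:
    # Assumption: a float k is a fraction of the dataset, an int k is a count.
    if isinstance(k, float):
        return int(round(n * k))
    return min(k, n)


def top_k(scores: Sequence[float], k: Union[int, float]) -> List[int]:
    n_keep = _num_to_keep(len(scores), k)
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:n_keep])


def bottom_k(scores: Sequence[float], k: Union[int, float]) -> List[int]:
    n_keep = _num_to_keep(len(scores), k)
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    return sorted(order[:n_keep])


def random_k(scores: Sequence[float], k: Union[int, float], seed: int = 0) -> List[int]:
    # Ignores the scores entirely; a seeded RNG keeps the baseline reproducible.
    n_keep = _num_to_keep(len(scores), k)
    return sorted(random.Random(seed).sample(range(len(scores)), n_keep))


def stratified(scores: Sequence[float], k: Union[int, float], num_strata: int = 2, seed: int = 0) -> List[int]:
    # Split score-sorted indices into equal-size quantile strata,
    # then sample the same fraction from every stratum.
    n = len(scores)
    frac = _num_to_keep(n, k) / n
    order = sorted(range(n), key=lambda i: scores[i])
    rng = random.Random(seed)
    keep: List[int] = []
    for s in range(num_strata):
        stratum = order[s * n // num_strata:(s + 1) * n // num_strata]
        keep.extend(rng.sample(stratum, int(round(len(stratum) * frac))))
    return sorted(keep)


scores = [0.9, 0.1, 0.5, 0.7, 0.3, 0.8, 0.2, 0.6, 0.4, 0.0]
print(top_k(scores, 0.3), bottom_k(scores, 3), stratified(scores, 0.4))
```

Stratified sampling keeps examples from across the whole score range, which avoids the bias of keeping only the easiest or only the hardest examples.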

Callbacks

  • ForgettingCallback: A TrainerCallback that records learning events during training to be used with the ForgettingScorer.

🎨 Extending dPrune

Creating your own custom components is straightforward.

Custom Scorer

Simply inherit from the Scorer base class and implement the score method.

from dprune import Scorer
from datasets import Dataset
import random

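class RandomScorer(Scorer):
    """Assigns each example a uniform random score in [0, 1)."""

    def score(self, dataset: Dataset, **kwargs) -> Dataset:
        scores = [random.random() for _ in range(len(dataset))]
        # Pruners select on the "score" column, so the new column must use that name.
        return dataset.add_column("score", scores)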

Custom Pruner

Inherit from the Pruner base class and implement the prune method.

from dprune import Pruner
from datasets import Dataset

class ThresholdPruner(Pruner):
    def __init__(self, threshold: float):
        self.threshold = threshold

    def prune(self, scored_dataset: Dataset, **kwargs) -> Dataset:
        # Keep only the examples whose score exceeds the threshold.
        indices_to_keep = [
            i for i, score in enumerate(scored_dataset['score'])
            if score > self.threshold
        ]
        return scored_dataset.select(indices_to_keep)

📓 Example Notebooks

1. Supervised Pruning with Forgetting Score

examples/supervised_pruning_with_forgetting_score.ipynb

Shows how to use forgetting scores to prune a dataset.

2. Unsupervised Pruning with K-Means

examples/unsupervised_pruning_with_kmeans.ipynb

Demonstrates clustering-based pruning using K-means to remove outlier examples.

3. Unsupervised Pruning with Perplexity

examples/unsupervised_pruning_with_perplexity.ipynb

Shows how to use perplexity scoring for data pruning in text summarization.
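For intuition about what a perplexity score measures, here is a tiny n-gram sketch. PerplexityScorer uses KenLM, whose models typically use modified Kneser-Ney smoothing; this sketch swaps in a much simpler add-one (Laplace) smoothed bigram model trained on a toy corpus, and the function names are hypothetical. Perplexity is exp of the average negative log-probability per predicted token, so lower values mean the text looks more like the training corpus.

```python
import math
from collections import Counter
from typing import List, Tuple


def _tokens(sentence: str) -> List[str]:
    # Lowercase whitespace tokenization with sentence boundary markers.
    return ["<s>"] + sentence.lower().split() + ["</s>"]


def train_bigram(corpus: List[str]) -> Tuple[Counter, Counter, int]:
    history, bigrams, vocab = Counter(), Counter(), set()
    for sentence in corpus:
        toks = _tokens(sentence)
        vocab.update(toks)
        history.update(toks[:-1])                # counts of each token used as context
        bigrams.update(zip(toks[:-1], toks[1:]))
    return history, bigrams, len(vocab)


def perplexity(sentence: str, history: Counter, bigrams: Counter, vocab_size: int) -> float:
    toks = _tokens(sentence)
    log_prob = 0.0
    for prev, word in zip(toks[:-1], toks[1:]):
        # Add-one smoothing gives unseen bigrams a small non-zero probability.
        p = (bigrams[(prev, word)] + 1) / (history[prev] + vocab_size)
        log_prob += math.log(p)
    n_predicted = len(toks) - 1
    return math.exp(-log_prob / n_predicted)


history, bigrams, vocab_size = train_bigram(
    ["the movie was great", "the movie was boring", "the plot was great"]
)
in_domain = perplexity("the movie was great", history, bigrams, vocab_size)
scrambled = perplexity("great was movie the", history, bigrams, vocab_size)
print(round(in_domain, 2), round(scrambled, 2))
```

Whether a pruning run keeps low- or high-perplexity examples is a design choice of the experiment; the notebook above shows the setup used there.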

🎓 Advanced Usage: Forgetting Score

Some pruning strategies require observing the model's behavior during training. dPrune supports this through the Hugging Face TrainerCallback mechanism. Here is how you would use the ForgettingScorer:

from transformers import Trainer
from dprune import ForgettingCallback, ForgettingScorer, PruningPipeline, TopKPruner

# Assumes model and a tokenized, labeled raw_dataset are already defined.

# 1. Initialize the callback and trainer
forgetting_callback = ForgettingCallback()
trainer = Trainer(
    model=model,
    train_dataset=raw_dataset,
    callbacks=[forgetting_callback],
)

# 2. Assign the trainer to the callback
forgetting_callback.trainer = trainer

# 3. Train the model. The callback will record events automatically.
trainer.train()

# 4. Create the scorer from the populated callback
scorer = ForgettingScorer(forgetting_callback)

# 5. Use the scorer in a pipeline as usual
pipeline = PruningPipeline(scorer=scorer, pruner=TopKPruner(k=0.8))  # Keep 80%
pruned_dataset = pipeline.run(raw_dataset)

print(f"Pruned with forgetting scores, final size: {len(pruned_dataset)}")
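The README does not spell out how ForgettingScorer turns recorded events into scores. A common definition, introduced by Toneva et al. (ICLR 2019), counts a forgetting event whenever an example goes from correctly to incorrectly classified between consecutive observations. The sketch below applies that definition to a hypothetical per-example correctness history; giving never-learned examples an infinite score is an assumption of this sketch, not confirmed dPrune behavior.

```python
import math
from typing import Dict, List


def forgetting_counts(history: Dict[int, List[bool]]) -> Dict[int, float]:
    """history maps example index -> per-observation correctness (True = correct)."""
    counts: Dict[int, float] = {}
    for idx, correct in history.items():
        if not any(correct):
            # Assumption of this sketch: never-learned examples rank as most forgettable.
            counts[idx] = math.inf
            continue
        # A forgetting event is a correct -> incorrect transition.
        counts[idx] = float(sum(1 for prev, cur in zip(correct, correct[1:]) if prev and not cur))
    return counts


history = {
    0: [True, True, True],                 # learned once, never forgotten
    1: [False, True, False, True, False],  # forgotten twice
    2: [False, False, False],              # never learned
    3: [True, False, True],                # forgotten once
}
forget_scores = forgetting_counts(history)
print(forget_scores)
```

Under this scoring, TopKPruner keeps the most frequently forgotten (hardest) examples and drops the "unforgettable" ones, which are often redundant.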

🧪 Running Tests

To run the full test suite, clone the repository and run pytest from the root directory:

git clone https://github.com/ahazeemi/dPrune.git
cd dPrune
# Install in editable mode with test dependencies
pip install -e ".[test]"
# Or, with uv
uv pip install -e ".[test]"

pytest

🤝 Contributing

Contributions are welcome! If you have a feature request, bug report, or want to add a new scorer or pruner, please open an issue or submit a pull request on GitHub.

📄 License

This project is licensed under the MIT License. See the LICENSE file for details.

📝 Citation

If you use dPrune in your research, please cite it as follows:

@software{dprune2025,
  author = {Azeemi, Abdul Hameed and Qazi, Ihsan Ayyub and Raza, Agha Ali},
  title = {dPrune: A Framework for Data Pruning},
  year = {2025},
  url = {https://github.com/ahazeemi/dPrune}
}

Alternatively, you can cite it in text as:

Abdul Hameed Azeemi, Ihsan Ayyub Qazi, and Agha Ali Raza. (2025). dPrune: A Framework for Data Pruning. https://github.com/ahazeemi/dPrune

Download files

Source Distribution

dprune-0.0.1.tar.gz (21.6 kB)

Built Distribution

dprune-0.0.1-py3-none-any.whl (16.2 kB)

File details

Details for the file dprune-0.0.1.tar.gz.

File metadata

  • Download URL: dprune-0.0.1.tar.gz
  • Size: 21.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.10

File hashes

Hashes for dprune-0.0.1.tar.gz
Algorithm    Hash digest
SHA256       b1e73ee579cb2074e8867038c540638ea17e28197aa4025d542d366d1d8ae770
MD5          4b2f7c53d81678082f7823c53d52cac0
BLAKE2b-256  fb02eb6e5b19282c927acbdefa871723ad161e03c2d60ef216f00e70f5f985a2

File details

Details for the file dprune-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: dprune-0.0.1-py3-none-any.whl
  • Size: 16.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.10

File hashes

Hashes for dprune-0.0.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       383973c6755b0cd13223bd199b09c8c2d865e6d51a53ac883a23afc0dd857902
MD5          e9e80c5cb3d1bc6310c71e4fddf458a2
BLAKE2b-256  3b1b8fe68a63d8a9d93bafcc7738c3c750796ea38c6bab5bd0bd950613747eb2
