
PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning


Prototype-Enhanced Aligned Representation Learning

Examples

Text Classification with BERT

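from transformers import AutoTokenizer, AutoModel
from pearl import PEARLPipeline
from sklearn.linear_model import LogisticRegression
import numpy as np
import torch

# Inputs you supply: train_texts / test_texts (lists of str),
# y_train / y_test (integer label arrays), num_classes (int)

# Load BERT once; from_pretrained returns the model in eval mode
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

def get_embeddings(texts, batch_size=32):
    """Return the [CLS] embedding of each text as an (n, 768) NumPy array."""
    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Position 0 of the last hidden state is the [CLS] token
        chunks.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())
    return np.concatenate(chunks)

# Get embeddings (batched to keep memory bounded)
X_train = get_embeddings(train_texts)
X_test = get_embeddings(test_texts)

# Apply PEARL: fit on the training split only, then transform both splits
pearl = PEARLPipeline(n_classes=num_classes, device=device)
pearl.fit(X_train, y_train)

X_train_enhanced = pearl.transform(X_train, mode='paf')
X_test_enhanced = pearl.transform(X_test, mode='paf')

# Train a downstream classifier on the enhanced embeddings
clf = LogisticRegression(max_iter=1000)  # default max_iter=100 often fails to converge on 768-d inputs
clf.fit(X_train_enhanced, y_train)
accuracy = clf.score(X_test_enhanced, y_test)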

See the examples/ directory for complete working examples.

Performance

On AG News, PEARL-enhanced embeddings improved F1 over raw embeddings for every classifier tested:

Dataset   Classifier  Raw F1  PEARL F1  Improvement (pp)
AG News   Logistic    0.8542  0.8876    +3.34
AG News   SVM         0.8621  0.8912    +2.91
AG News   MLP         0.8698  0.9045    +3.47
AG News   RAG         0.8823  0.9156    +3.33

Gains range from 2.91 (SVM) to 3.47 (MLP) F1 percentage points, so the improvement is consistent across all four classifiers on this dataset.
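The improvement figures equal the absolute F1 difference (PEARL F1 minus Raw F1) expressed in percentage points, not a relative gain. A minimal check using the AG News rows from the table; the helper name gains is illustrative only:

```python
# F1 scores copied from the AG News table: (raw embeddings, PEARL-enhanced)
results = {
    "Logistic": (0.8542, 0.8876),
    "SVM": (0.8621, 0.8912),
    "MLP": (0.8698, 0.9045),
    "RAG": (0.8823, 0.9156),
}

def gains(raw, enhanced):
    """Return (absolute gain in percentage points, relative gain in percent)."""
    absolute_pp = (enhanced - raw) * 100
    relative_pct = (enhanced - raw) / raw * 100
    return absolute_pp, relative_pct

for name, (raw, enhanced) in results.items():
    abs_pp, rel_pct = gains(raw, enhanced)
    print(f"{name:<8} {abs_pp:+.2f} pp  ({rel_pct:+.2f}% relative)")
```

Relative gains come out somewhat larger (roughly 3.4% to 4.0%) because each difference is divided by a raw F1 below 1.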

API Reference

PEARLPipeline

Main interface for PEARL.

Methods:

  • fit(X_train, y_train, X_val, y_val, **kwargs): Train the pipeline
  • transform(X, mode='paf'): Transform embeddings
  • fit_transform(X, y, **kwargs): Fit and transform in one step
  • save(path): Save pipeline to disk
  • load(path, device): Load pipeline from disk (class method)

SignalExtractor

Neural network for signal extraction.

Methods:

  • forward(x): Forward pass returning all outputs
  • get_enhanced_embedding(x): Extract enhanced embedding
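The README does not describe the SignalExtractor architecture. As a purely hypothetical illustration of a network whose forward pass returns several outputs, drawing on the multi-task and denoising ideas listed under Acknowledgments, a NumPy sketch might pair one shared encoder (the enhanced embedding) with a classification head and a reconstruction head. The class name ToySignalExtractor, the layer layout, and the output keys are assumptions, and the weights are random and untrained:

```python
import numpy as np

class ToySignalExtractor:
    """Hypothetical multi-output network (NOT the PEARL implementation).

    A shared encoder produces an enhanced embedding; two heads reuse it:
    a classification head (logits) and a reconstruction head mapping back
    to the input space, as a denoising / multi-task model might.
    """

    def __init__(self, in_dim, hidden_dim, n_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.W_enc = rng.normal(0.0, 1.0 / np.sqrt(in_dim), (in_dim, hidden_dim))
        self.b_enc = np.zeros(hidden_dim)
        self.W_cls = rng.normal(0.0, 1.0 / np.sqrt(hidden_dim), (hidden_dim, n_classes))
        self.W_rec = rng.normal(0.0, 1.0 / np.sqrt(hidden_dim), (hidden_dim, in_dim))

    def get_enhanced_embedding(self, x):
        # Shared encoder: affine map followed by ReLU
        return np.maximum(0.0, x @ self.W_enc + self.b_enc)

    def forward(self, x):
        # Return every output so a multi-task loss could combine them
        z = self.get_enhanced_embedding(x)
        return {
            "embedding": z,
            "logits": z @ self.W_cls,
            "reconstruction": z @ self.W_rec,
        }
```

For example, ToySignalExtractor(768, 256, 4).forward(X) on an (n, 768) array returns arrays of shape (n, 256), (n, 4), and (n, 768).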

PrototypeFeatures

Prototype-based feature generator.

Methods:

  • fit(embeddings, labels): Learn prototypes from training data
  • transform(embeddings): Generate prototype features
  • get_augmented(embeddings): Get embeddings + features
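The README does not specify how prototypes are learned or which features are produced. A minimal sketch, assuming one prototype per class computed as the class-mean embedding and cosine similarity to each prototype as the feature; ToyPrototypeFeatures is a hypothetical name, and the real PrototypeFeatures may work differently:

```python
import numpy as np

def _unit_rows(a):
    # L2-normalise each row, guarding against zero vectors
    return a / np.clip(np.linalg.norm(a, axis=1, keepdims=True), 1e-12, None)

class ToyPrototypeFeatures:
    """Hypothetical prototype feature generator (NOT the PEARL implementation)."""

    def fit(self, embeddings, labels):
        embeddings = np.asarray(embeddings, dtype=float)
        labels = np.asarray(labels)
        self.classes_ = np.unique(labels)
        # One prototype per class: the mean training embedding of that class
        self.prototypes_ = np.stack(
            [embeddings[labels == c].mean(axis=0) for c in self.classes_]
        )
        return self

    def transform(self, embeddings):
        # Cosine similarity of every embedding to every prototype -> (n, n_classes)
        return _unit_rows(np.asarray(embeddings, dtype=float)) @ _unit_rows(self.prototypes_).T

    def get_augmented(self, embeddings):
        # Assumption: "embeddings + features" means appending features as extra columns
        return np.hstack([embeddings, self.transform(embeddings)])
```

Under these assumptions, get_augmented turns a (n, d) matrix into (n, d + n_classes), which is one way a downstream classifier could see both the raw geometry and class-prototype affinities.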

RAGClassifierWrapper

Retrieval-augmented classifier.

Methods:

  • fit(X_train, y_train, X_val, y_val, **kwargs): Train the model
  • predict(X): Predict class labels
  • predict_proba(X): Predict class probabilities
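The README does not say how retrieval is used in RAGClassifierWrapper. A minimal sketch of one common retrieval-based scheme, assuming k-nearest-neighbour voting over stored training embeddings weighted by cosine similarity; ToyRetrievalClassifier and its k parameter are hypothetical, and the validation arguments of the documented fit are omitted:

```python
import numpy as np

class ToyRetrievalClassifier:
    """Hypothetical retrieval-based classifier (NOT the PEARL implementation).

    fit() stores training embeddings as a retrieval index; predict_proba()
    retrieves the k most cosine-similar training items and returns the
    similarity-weighted share of each class among them.
    """

    def __init__(self, k=5):
        self.k = k

    @staticmethod
    def _unit(a):
        return a / np.clip(np.linalg.norm(a, axis=1, keepdims=True), 1e-12, None)

    def fit(self, X_train, y_train):
        self.index_ = self._unit(np.asarray(X_train, dtype=float))
        self.classes_, self.y_idx_ = np.unique(y_train, return_inverse=True)
        return self

    def predict_proba(self, X):
        sims = self._unit(np.asarray(X, dtype=float)) @ self.index_.T  # (n_query, n_train)
        k = min(self.k, sims.shape[1])
        top = np.argsort(-sims, axis=1)[:, :k]  # indices of the k nearest neighbours
        proba = np.zeros((sims.shape[0], len(self.classes_)))
        for i, neighbours in enumerate(top):
            # Clip negative similarities so they cannot subtract votes
            weights = np.clip(sims[i, neighbours], 0.0, None)
            np.add.at(proba[i], self.y_idx_[neighbours], weights)
        totals = proba.sum(axis=1, keepdims=True)
        # Uniform fallback when every retrieved neighbour had zero weight
        safe = np.where(totals > 0, totals, 1.0)
        return np.where(totals > 0, proba / safe, 1.0 / len(self.classes_))

    def predict(self, X):
        return self.classes_[self.predict_proba(X).argmax(axis=1)]
```

Storing the full training set keeps fit trivial, at the cost of a similarity computation against every stored item per query; an approximate nearest-neighbour index would be the usual swap for large datasets.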

Requirements

  • Python >= 3.8
  • PyTorch >= 1.12.0
  • NumPy >= 1.20.0
  • scikit-learn >= 1.0.0
  • pandas >= 1.3.0
  • openpyxl >= 3.0.0
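A minimal sketch for checking installed versions against the minimums above, using the standard library's importlib.metadata (available from Python 3.8). The helper names are illustrative, and the simple parser only compares the leading numeric release (pre-release tags are ignored); the packaging library's Version class is the more robust option:

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the Requirements list (PyTorch is distributed as "torch")
MINIMUMS = {
    "torch": "1.12.0",
    "numpy": "1.20.0",
    "scikit-learn": "1.0.0",
    "pandas": "1.3.0",
    "openpyxl": "3.0.0",
}

def release_tuple(v):
    """Parse the leading numeric release, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    parts = [int(p) for p in m.group(0).split(".")] if m else [0]
    return tuple(parts + [0] * (3 - len(parts)))  # pad '1.0' to (1, 0, 0)

def check_requirements(minimums=MINIMUMS):
    """Return {package: (installed_version_or_None, meets_minimum)}."""
    report = {}
    for pkg, minimum in minimums.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            report[pkg] = (None, False)
            continue
        report[pkg] = (installed, release_tuple(installed) >= release_tuple(minimum))
    return report

if __name__ == "__main__":
    for pkg, (installed, ok) in check_requirements().items():
        status = "OK" if ok else "missing or too old"
        print(f"{pkg:<14} {installed or '-':<14} {status}")
```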

Citation

If you use PEARL in your research, please cite:

@software{pearl2024,
  title={PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning},
  author={PEARL Contributors},
  year={2024},
  url={https://github.com/yourusername/pearl}
}

Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.

Development Setup

git clone https://github.com/yourusername/pearl.git
cd pearl
pip install -e ".[dev]"

# Run tests
pytest tests/

# Format code
black pearl/

# Type checking
mypy pearl/

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

PEARL was developed to address the challenge of enhancing learned embeddings for downstream classification tasks. It builds on ideas from:

  • Signal processing and denoising
  • Prototype-based learning
  • Multi-task learning
  • Retrieval-augmented generation

Roadmap

  • Support for additional embedding types (images, audio)
  • Pre-trained models for common datasets
  • Integration with popular frameworks (HuggingFace, PyTorch Lightning)
  • Online/incremental learning support
  • Multi-label classification support
  • Comprehensive benchmarking suite

Made with ❤️ by the PEARL team

Download files

Download the file for your platform.

Source Distribution

pearl_h-0.1.1.tar.gz (19.7 kB)

Uploaded Source

Built Distribution

pearl_h-0.1.1-py3-none-any.whl (21.0 kB)

Uploaded Python 3

File details

Details for the file pearl_h-0.1.1.tar.gz.

File metadata

  • Download URL: pearl_h-0.1.1.tar.gz
  • Size: 19.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.7

File hashes

Hashes for pearl_h-0.1.1.tar.gz
Algorithm Hash digest
SHA256 066579eb97cb71555cb20bdd6a23f672b97cc7cbd41f320ef06f4d5778d4df9c
MD5 5a6518568b58b73fb876008bfd8b0164
BLAKE2b-256 6a4c38ca44b2170b64358d6a5dd5bb154c9a1b990f40c46c09c486989668fcc6

File details

Details for the file pearl_h-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: pearl_h-0.1.1-py3-none-any.whl
  • Size: 21.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.7

File hashes

Hashes for pearl_h-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 d58855a2feb68cf83aa89937e9c639249a285370989c1d3b579083830bccf446
MD5 29ca72e49887802a2940abcfe030a2e8
BLAKE2b-256 7542ca8ff6f2e18fe6e3376d7440aa1168ea407e145bf31faaca226f4d747436
