PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning
Project description
Prototype-Enhanced Aligned Representation Learning
Examples
Text Classification with BERT
from transformers import AutoTokenizer, AutoModel
from pearl import PEARLPipeline
import torch
# Extract BERT embeddings
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')
def get_embeddings(texts):
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
outputs = model(**inputs)
return outputs.last_hidden_state[:, 0, :].numpy()
# Get embeddings
X_train = get_embeddings(train_texts)
X_test = get_embeddings(test_texts)
# Apply PEARL
pearl = PEARLPipeline(n_classes=num_classes, device='cuda')
pearl.fit(X_train, y_train)
X_train_enhanced = pearl.transform(X_train, mode='paf')
X_test_enhanced = pearl.transform(X_test, mode='paf')
# Train classifier
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression()
clf.fit(X_train_enhanced, y_train)
accuracy = clf.score(X_test_enhanced, y_test)
See the examples/ directory for complete working examples:
basic_usage.py: Simple synthetic data exampletext_classification.py: Real-world text classification with BERT
Performance
PEARL consistently improves classification performance across multiple benchmarks:
| Dataset | Classifier | Raw F1 | PEARL F1 | Improvement |
|---|---|---|---|---|
| AG News | Logistic | 0.8542 | 0.8876 | +3.34% |
| AG News | SVM | 0.8621 | 0.8912 | +2.91% |
| AG News | MLP | 0.8698 | 0.9045 | +3.47% |
| AG News | RAG | 0.8823 | 0.9156 | +3.33% |
Results show consistent improvements across different classifiers and datasets.
API Reference
PEARLPipeline
Main interface for PEARL.
Methods:
fit(X_train, y_train, X_val, y_val, **kwargs): Train the pipelinetransform(X, mode='paf'): Transform embeddingsfit_transform(X, y, **kwargs): Fit and transform in one stepsave(path): Save pipeline to diskload(path, device): Load pipeline from disk (class method)
SignalExtractor
Neural network for signal extraction.
Methods:
forward(x): Forward pass returning all outputsget_enhanced_embedding(x): Extract enhanced embedding
PrototypeFeatures
Prototype-based feature generator.
Methods:
fit(embeddings, labels): Learn prototypes from training datatransform(embeddings): Generate prototype featuresget_augmented(embeddings): Get embeddings + features
RAGClassifierWrapper
Retrieval-augmented classifier.
Methods:
fit(X_train, y_train, X_val, y_val, **kwargs): Train the modelpredict(X): Predict class labelspredict_proba(X): Predict class probabilities
Requirements
- Python >= 3.8
- PyTorch >= 1.12.0
- NumPy >= 1.20.0
- scikit-learn >= 1.0.0
- pandas >= 1.3.0
- openpyxl >= 3.0.0
Citation
If you use PEARL in your research, please cite:
@software{pearl2024,
title={PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning},
author={PEARL Contributors},
year={2024},
url={https://github.com/yourusername/pearl}
}
Contributing
Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.
Development Setup
git clone https://github.com/yourusername/pearl.git
cd pearl
pip install -e ".[dev]"
# Run tests
pytest tests/
# Format code
black pearl/
# Type checking
mypy pearl/
License
This project is licensed under the MIT License - see the LICENSE file for details.
Acknowledgments
PEARL was developed to address the challenge of enhancing learned embeddings for downstream classification tasks. It builds on ideas from:
- Signal processing and denoising
- Prototype-based learning
- Multi-task learning
- Retrieval-augmented generation
Support
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Email: pearl@example.com
Roadmap
- Support for additional embedding types (images, audio)
- Pre-trained models for common datasets
- Integration with popular frameworks (HuggingFace, PyTorch Lightning)
- Online/incremental learning support
- Multi-label classification support
- Comprehensive benchmarking suite
Made with ❤️ by the PEARL team
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file pearl_h-0.1.1.tar.gz.
File metadata
- Download URL: pearl_h-0.1.1.tar.gz
- Upload date:
- Size: 19.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
066579eb97cb71555cb20bdd6a23f672b97cc7cbd41f320ef06f4d5778d4df9c
|
|
| MD5 |
5a6518568b58b73fb876008bfd8b0164
|
|
| BLAKE2b-256 |
6a4c38ca44b2170b64358d6a5dd5bb154c9a1b990f40c46c09c486989668fcc6
|
File details
Details for the file pearl_h-0.1.1-py3-none-any.whl.
File metadata
- Download URL: pearl_h-0.1.1-py3-none-any.whl
- Upload date:
- Size: 21.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
d58855a2feb68cf83aa89937e9c639249a285370989c1d3b579083830bccf446
|
|
| MD5 |
29ca72e49887802a2940abcfe030a2e8
|
|
| BLAKE2b-256 |
7542ca8ff6f2e18fe6e3376d7440aa1168ea407e145bf31faaca226f4d747436
|