PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning
PEARL is a framework for enhancing embeddings through signal extraction and prototype-guided feature augmentation. It aims to improve classification performance on embedding-based tasks by separating discriminative signal from noise and augmenting embeddings with prototype-based features.
Key Features
- Signal Extraction: Separates discriminative signal from noise in embeddings using a trained neural network
- Prototype-Guided Features (PAF): Augments embeddings with prototype-based similarity features
- Easy-to-use API: Simple scikit-learn-like interface
- Flexible: Works with any fixed-size embedding vectors (BERT, ResNet, custom embeddings, etc.)
- Reported Results: F1 improvements across several classifiers (see Performance)
- GPU Accelerated: Built on PyTorch; runs on CUDA or CPU
Installation
From PyPI (recommended)
pip install pearl-ai
From source
git clone https://github.com/yourusername/pearl.git
cd pearl
pip install -e .
Optional dependencies
For examples and advanced features:
pip install "pearl-ai[examples]" # Install with example dependencies (quotes keep shells like zsh from expanding the brackets)
pip install "pearl-ai[dev]" # Install with development tools
pip install "pearl-ai[all]" # Install everything
Quick Start
import numpy as np
from pearl import PEARLPipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
# Your embeddings and labels
X_train, y_train = ... # Shape: [N, D], [N]
X_test, y_test = ...
# Initialize PEARL
pearl = PEARLPipeline(
n_classes=10,
device='cuda' # or 'cpu'
)
# Fit PEARL on training data
pearl.fit(X_train, y_train)
# Transform embeddings
X_train_enhanced = pearl.transform(X_train, mode='enhanced')
X_test_enhanced = pearl.transform(X_test, mode='enhanced')
# Use with any classifier
clf = LogisticRegression()
clf.fit(X_train_enhanced, y_train)
pred = clf.predict(X_test_enhanced)
print(f"F1 Score: {f1_score(y_test, pred, average='macro'):.4f}")
How PEARL Works
PEARL enhances embeddings through two key steps:
1. Signal Extraction
The Signal Extractor learns to separate embeddings into:
- Signal: Class-discriminative information
- Noise: Non-discriminative variations
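One way to picture this split is a NumPy forward pass sized with the `input_dim`, `signal_dim` and `hidden_dims` parameters of `PEARLPipeline`. Everything about the architecture here is an assumption for illustration: an MLP encoder with ReLU layers, two separate linear heads producing a signal vector and an equally sized noise vector, and a linear decoder that reconstructs the input from both. Weights are random, dropout is omitted, and the helper names `init_extractor` and `forward` are hypothetical; the real `SignalExtractor` may differ.

```python
import numpy as np


def init_extractor(input_dim=768, signal_dim=256, hidden_dims=(512, 384), seed=42):
    rng = np.random.default_rng(seed)

    def linear(n_in, n_out):
        # He-style scaled random weights and a zero bias (assumed initialisation)
        return rng.standard_normal((n_in, n_out)) * np.sqrt(2.0 / n_in), np.zeros(n_out)

    dims = (input_dim, *hidden_dims)
    return {
        "encoder": [linear(a, b) for a, b in zip(dims[:-1], dims[1:])],
        "signal_head": linear(dims[-1], signal_dim),
        "noise_head": linear(dims[-1], signal_dim),  # assumed: noise has the same size as signal
        "decoder": linear(2 * signal_dim, input_dim),
    }


def forward(params, x):
    h = x
    for W, b in params["encoder"]:
        h = np.maximum(0.0, h @ W + b)  # ReLU hidden layers: 768 -> 512 -> 384
    W, b = params["signal_head"]
    signal = h @ W + b  # class-discriminative part, [N, signal_dim]
    W, b = params["noise_head"]
    noise = h @ W + b  # non-discriminative part, [N, signal_dim]
    W, b = params["decoder"]
    x_hat = np.concatenate([signal, noise], axis=1) @ W + b  # reconstruction, [N, input_dim]
    return signal, noise, x_hat
```

The returned `signal`, `noise` and `x_hat` are the quantities that the training losses listed next operate on.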
The Signal Extractor is trained with a multi-task objective that combines:
- Reconstruction loss (preserve information)
- Centroid alignment loss (align with class centers)
- Contrastive loss (separate classes)
- Orthogonality loss (decorrelate signal and noise)
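A minimal NumPy sketch of how these four terms could be combined into a single weighted objective. The specific forms are assumptions for illustration, not PEARL's implementation: mean squared error for reconstruction, squared distance to the class mean for centroid alignment, a margin-based pairwise loss for the contrastive term, and the squared cross-covariance between signal and noise for orthogonality. The default weights mirror the `recon_weight`, `centroid_weight`, `contrast_weight` and `ortho_weight` values in the Custom Configuration example.

```python
import numpy as np


def reconstruction_loss(x, x_hat):
    # Mean squared error between the input embedding and its reconstruction
    return float(np.mean((x - x_hat) ** 2))


def centroid_loss(signal, labels):
    # Average squared distance of each signal vector to the mean of its class
    total = 0.0
    for c in np.unique(labels):
        members = signal[labels == c]
        total += float(np.sum((members - members.mean(axis=0)) ** 2))
    return total / len(signal)


def contrastive_loss(signal, labels, margin=1.0):
    # Pull same-class pairs together; push different-class pairs at least `margin` apart
    dist = np.linalg.norm(signal[:, None, :] - signal[None, :, :], axis=-1)
    same = labels[:, None] == labels[None, :]
    off_diag = ~np.eye(len(labels), dtype=bool)
    pos = dist[same & off_diag] ** 2
    neg = np.maximum(0.0, margin - dist[~same]) ** 2
    return float(np.concatenate([pos, neg]).mean())


def orthogonality_loss(signal, noise):
    # Squared Frobenius norm of the cross-covariance between centred signal and noise
    s = signal - signal.mean(axis=0)
    n = noise - noise.mean(axis=0)
    cross = s.T @ n / len(signal)
    return float(np.sum(cross ** 2))


def multitask_loss(x, x_hat, signal, noise, labels,
                   recon_w=1.0, centroid_w=2.0, contrast_w=0.5, ortho_w=0.5):
    # Weighted sum of the four objectives
    return (recon_w * reconstruction_loss(x, x_hat)
            + centroid_w * centroid_loss(signal, labels)
            + contrast_w * contrastive_loss(signal, labels)
            + ortho_w * orthogonality_loss(signal, noise))
```

With perfect reconstruction, tight and well-separated classes, and noise uncorrelated with signal, every term is zero; raising a weight makes the optimiser prioritise that property.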
2. Prototype-Guided Augmentation (PAF)
PAF augments embeddings with features computed from learned class prototypes:
- Maximum similarity to per-class prototypes
- Mean similarity to per-class prototypes
- Similarity to class centroids
- Decision margin (confidence)
- Prediction entropy (uncertainty)
These features give downstream classifiers explicit class-similarity and confidence information in addition to the embedding itself.
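The five feature groups above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not PEARL's `PAFAugmentor`: prototypes come from a small per-class k-means, similarities are cosine similarities, the margin is the gap between the two highest per-class maximum similarities, and the entropy is taken over a temperature-scaled softmax of those similarities. The helper names `fit_prototypes` and `paf_features` are hypothetical, and every class is assumed to have at least one training sample.

```python
import numpy as np


def _normalize(a):
    # Scale vectors along the last axis to unit length (guarding against zero norms)
    return a / np.maximum(np.linalg.norm(a, axis=-1, keepdims=True), 1e-12)


def fit_prototypes(X, y, n_classes, n_per_class=3, n_iter=10, seed=0):
    # Per-class k-means: returns prototypes [C, P, D] and class centroids [C, D]
    rng = np.random.default_rng(seed)
    protos = np.zeros((n_classes, n_per_class, X.shape[1]))
    centroids = np.zeros((n_classes, X.shape[1]))
    for c in range(n_classes):
        Xc = X[y == c]
        centroids[c] = Xc.mean(axis=0)
        # Initialise from random class members (with replacement if the class is small)
        p = Xc[rng.choice(len(Xc), n_per_class, replace=len(Xc) < n_per_class)]
        for _ in range(n_iter):
            assign = np.argmin(((Xc[:, None, :] - p[None, :, :]) ** 2).sum(-1), axis=1)
            for k in range(n_per_class):
                if np.any(assign == k):
                    p[k] = Xc[assign == k].mean(axis=0)
        protos[c] = p
    return protos, centroids


def paf_features(X, protos, centroids, temperature=0.1):
    # Returns [N, 3C + 2]: max sim, mean sim, centroid sim (C columns each), margin, entropy
    Xn, Pn, Cn = _normalize(X), _normalize(protos), _normalize(centroids)
    proto_sim = np.einsum("nd,cpd->ncp", Xn, Pn)  # cosine similarity to every prototype
    max_sim = proto_sim.max(axis=2)  # best-matching prototype per class
    mean_sim = proto_sim.mean(axis=2)  # average over each class's prototypes
    cent_sim = Xn @ Cn.T  # similarity to class centroids
    top2 = np.sort(max_sim, axis=1)[:, -2:]
    margin = top2[:, 1] - top2[:, 0]  # gap between best and runner-up class
    logits = max_sim / temperature
    logits -= logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    entropy = -(probs * np.log(probs + 1e-12)).sum(axis=1)
    return np.hstack([max_sim, mean_sim, cent_sim, margin[:, None], entropy[:, None]])
```

Under these assumptions the augmented embedding would be `np.hstack([X, paf_features(X, protos, centroids)])`, adding 3C + 2 columns for C classes.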
Transformation Modes
PEARL supports three transformation modes:
# Mode 1: Raw (no transformation)
X_raw = pearl.transform(X, mode='raw')
# Mode 2: Enhanced (signal extraction only)
X_enhanced = pearl.transform(X, mode='enhanced')
# Mode 3: PAF (enhanced + prototype features) - RECOMMENDED
X_paf = pearl.transform(X, mode='paf')
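Conceptually, the three modes differ only in what gets stacked together. In the sketch below, `encode` stands in for the trained signal extractor and `prototype_features` for a fitted PAF step; both are hypothetical placeholders rather than PEARL functions, and computing the prototype features on the enhanced embedding (rather than the raw one) is an assumption.

```python
from typing import Callable

import numpy as np

Array = np.ndarray


def transform(X: Array, mode: str,
              encode: Callable[[Array], Array],
              prototype_features: Callable[[Array], Array]) -> Array:
    if mode == "raw":
        return X  # input embeddings, unchanged
    enhanced = encode(X)  # signal representation only
    if mode == "enhanced":
        return enhanced
    if mode == "paf":
        # Enhanced embedding with prototype-based features appended column-wise
        return np.hstack([enhanced, prototype_features(enhanced)])
    raise ValueError(f"unknown mode: {mode!r}")
```

For example, with a 768-dimensional input, a 256-dimensional signal and 32 prototype feature columns, the three modes would return 768, 256 and 288 columns respectively.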
Advanced Usage
Custom Configuration
from pearl import PEARLPipeline
pearl = PEARLPipeline(
n_classes=10,
input_dim=768, # Auto-detected if None
signal_dim=256, # Signal representation dimension
hidden_dims=(512, 384), # Hidden layers for encoder
n_prototypes_per_class=3, # Prototypes per class
device='cuda',
dropout=0.3,
random_state=42
)
# Fine-tune training parameters
pearl.fit(
X_train, y_train,
X_val, y_val,
lr=1e-3,
weight_decay=1e-4,
batch_size=128,
epochs=100,
patience=20,
recon_weight=1.0, # Reconstruction loss weight
centroid_weight=2.0, # Centroid loss weight
contrast_weight=0.5, # Contrastive loss weight
ortho_weight=0.5, # Orthogonality loss weight
verbose=True
)
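The `epochs` and `patience` arguments suggest validation-based early stopping. Assuming the usual semantics (stop once the validation loss has not improved for `patience` consecutive epochs, and keep the best epoch), the control flow looks roughly like this; `train_one_epoch` and `validate` are hypothetical callables, not part of PEARL's API.

```python
from typing import Callable, Tuple


def fit_with_early_stopping(train_one_epoch: Callable[[int], None],
                            validate: Callable[[int], float],
                            epochs: int = 100,
                            patience: int = 20) -> Tuple[int, float]:
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch in range(epochs):
        train_one_epoch(epoch)
        val_loss = validate(epoch)
        if val_loss < best_loss:
            best_loss, best_epoch = val_loss, epoch
            epochs_without_improvement = 0  # a real trainer would checkpoint weights here
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # no improvement for `patience` epochs in a row
    return best_epoch, best_loss
```

Under this reading, `epochs=100, patience=20` means at most 100 epochs, ending early after 20 epochs without a new best validation loss.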
Save and Load Pipeline
# Save trained pipeline
pearl.save('./my_pearl_model')
# Load pipeline
from pearl import PEARLPipeline
pearl = PEARLPipeline.load('./my_pearl_model', device='cuda')
# Use immediately
X_enhanced = pearl.transform(X_test, mode='enhanced')
Using Individual Components
from pearl import SignalExtractorTrainer, PAFAugmentor, SignalExtractor
# 1. Signal Extraction
model = SignalExtractor(input_dim=768, signal_dim=256, n_classes=10)
trainer = SignalExtractorTrainer(model, device='cuda')
trainer.fit(X_train, y_train, X_val, y_val)
X_enhanced = trainer.transform(X_test)
# 2. PAF Features
paf = PAFAugmentor(n_classes=10, n_prototypes_per_class=3)
paf.fit(X_train, y_train)
X_paf = paf.transform(X_test) # Augmented embeddings
Using with RAG Classifier
PEARL also includes a retrieval-augmented (RAG-style) classifier that retrieves the k nearest training examples for each input:
from pearl import RAGClassifierWrapper
rag = RAGClassifierWrapper(
embed_dim=768,
n_classes=10,
k=8, # Number of neighbors to retrieve
device='cuda'
)
rag.fit(X_train, y_train, X_val, y_val)
predictions = rag.predict(X_test)
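As a rough mental model of what `k` controls, the sketch below classifies each query by a similarity-weighted vote over its `k` most similar training embeddings (cosine similarity). This is plain weighted k-nearest-neighbour voting written for illustration; `RAGClassifierWrapper` is trained with validation data and may combine retrieval with a learned model, so treat it as an assumption rather than the wrapper's implementation. The function names are hypothetical.

```python
import numpy as np


def knn_vote_predict_proba(X_train, y_train, X_query, n_classes, k=8):
    # Cosine similarity between every query and every stored training embedding
    A = X_train / np.linalg.norm(X_train, axis=1, keepdims=True)
    Q = X_query / np.linalg.norm(X_query, axis=1, keepdims=True)
    sims = Q @ A.T  # [n_query, n_train]
    k = min(k, len(X_train))
    top = np.argsort(-sims, axis=1)[:, :k]  # indices of the k most similar neighbours
    weights = np.maximum(np.take_along_axis(sims, top, axis=1), 0.0)  # drop negative similarities
    proba = np.zeros((len(X_query), n_classes))
    for i in range(len(X_query)):
        np.add.at(proba[i], y_train[top[i]], weights[i])  # accumulate votes per class
    proba += 1e-12  # avoid division by zero when all weights are 0
    return proba / proba.sum(axis=1, keepdims=True)


def knn_vote_predict(X_train, y_train, X_query, n_classes, k=8):
    return knn_vote_predict_proba(X_train, y_train, X_query, n_classes, k).argmax(axis=1)
```

Larger `k` smooths predictions over more neighbours; smaller `k` relies on the closest matches only.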
Examples
Text Classification with BERT
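from transformers import AutoTokenizer, AutoModel
from pearl import PEARLPipeline
from sklearn.linear_model import LogisticRegression
import numpy as np
import torch
# Load BERT and put it in inference mode (disables dropout)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')
model.eval()
def get_embeddings(texts, batch_size=32):
    """Return [CLS] embeddings of shape [len(texts), 768] for a list of strings."""
    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            outputs = model(**inputs)
        chunks.append(outputs.last_hidden_state[:, 0, :].numpy())
    return np.concatenate(chunks, axis=0)
# train_texts/test_texts (lists of str), y_train/y_test and num_classes come from your dataset
X_train = get_embeddings(train_texts)
X_test = get_embeddings(test_texts)
# Apply PEARL
pearl = PEARLPipeline(n_classes=num_classes, device='cuda')
pearl.fit(X_train, y_train)
X_train_enhanced = pearl.transform(X_train, mode='paf')
X_test_enhanced = pearl.transform(X_test, mode='paf')
# Train classifier (higher max_iter helps convergence on high-dimensional features)
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train_enhanced, y_train)
accuracy = clf.score(X_test_enhanced, y_test)
print(f"Accuracy: {accuracy:.4f}")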
See the examples/ directory for complete working examples:
- basic_usage.py: Simple synthetic data example
- text_classification.py: Real-world text classification with BERT
Performance
PEARL improves classification performance on the AG News benchmark for each of the four classifiers tested:
| Dataset | Classifier | Raw F1 | PEARL F1 | Improvement (percentage points) |
|---|---|---|---|---|
| AG News | Logistic | 0.8542 | 0.8876 | +3.34 |
| AG News | SVM | 0.8621 | 0.8912 | +2.91 |
| AG News | MLP | 0.8698 | 0.9045 | +3.47 |
| AG News | RAG | 0.8823 | 0.9156 | +3.33 |
Gains are consistent across classifiers, ranging from +2.91 to +3.47 F1 points.
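The Improvement column is the absolute difference between the two F1 scores expressed in points (0.8876 - 0.8542 = 0.0334, i.e. +3.34 points), not a relative gain. The short calculation below reproduces that column from the table values and also reports the relative improvement for comparison.

```python
from typing import Tuple

# F1 scores from the table above: (raw, PEARL)
results = {
    "Logistic": (0.8542, 0.8876),
    "SVM": (0.8621, 0.8912),
    "MLP": (0.8698, 0.9045),
    "RAG": (0.8823, 0.9156),
}


def improvements(raw: float, pearl: float) -> Tuple[float, float]:
    absolute_pts = (pearl - raw) * 100  # percentage points
    relative_pct = (pearl - raw) / raw * 100  # percent relative to the raw score
    return absolute_pts, relative_pct


for name, (raw, pearl) in results.items():
    pts, rel = improvements(raw, pearl)
    print(f"{name:8s} +{pts:.2f} pts  (+{rel:.2f}% relative)")
```

Measured relative to the raw scores, the same gains correspond to roughly 3.4% to 4.0%.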
API Reference
PEARLPipeline
Main interface for PEARL.
Methods:
- fit(X_train, y_train, X_val, y_val, **kwargs): Train the pipeline
- transform(X, mode='paf'): Transform embeddings
- fit_transform(X, y, **kwargs): Fit and transform in one step
- save(path): Save pipeline to disk
- load(path, device): Load pipeline from disk (class method)
SignalExtractor
Neural network for signal extraction.
Methods:
- forward(x): Forward pass returning all outputs
- get_enhanced_embedding(x): Extract enhanced embedding
PrototypeFeatures
Prototype-based feature generator.
Methods:
- fit(embeddings, labels): Learn prototypes from training data
- transform(embeddings): Generate prototype features
- get_augmented(embeddings): Get embeddings + features
RAGClassifierWrapper
Retrieval-augmented classifier.
Methods:
- fit(X_train, y_train, X_val, y_val, **kwargs): Train the model
- predict(X): Predict class labels
- predict_proba(X): Predict class probabilities
Requirements
- Python >= 3.8
- PyTorch >= 1.12.0
- NumPy >= 1.20.0
- scikit-learn >= 1.0.0
- pandas >= 1.3.0
- openpyxl >= 3.0.0
Citation
If you use PEARL in your research, please cite:
@software{pearl2024,
title={PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning},
author={PEARL Contributors},
year={2024},
url={https://github.com/yourusername/pearl}
}
Contributing
Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.
Development Setup
git clone https://github.com/yourusername/pearl.git
cd pearl
pip install -e ".[dev]"
# Run tests
pytest tests/
# Format code
black pearl/
# Type checking
mypy pearl/
License
This project is licensed under the MIT License - see the LICENSE file for details.
Acknowledgments
PEARL was developed to address the challenge of enhancing learned embeddings for downstream classification tasks. It builds on ideas from:
- Signal processing and denoising
- Prototype-based learning
- Multi-task learning
- Retrieval-augmented generation
Support
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Email: pearl@example.com
Roadmap
- Support for additional embedding types (images, audio)
- Pre-trained models for common datasets
- Integration with popular frameworks (HuggingFace, PyTorch Lightning)
- Online/incremental learning support
- Multi-label classification support
- Comprehensive benchmarking suite
Made with ❤️ by the PEARL team
File details
Details for the file pearl_h-0.1.0.tar.gz.
File metadata
- Download URL: pearl_h-0.1.0.tar.gz
- Size: 22.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c24564ee610fdb339de315a806897318186eaaf88a444722b42a16607585dd82 |
| MD5 | bdaeaa4d58db419d119151224adfdd33 |
| BLAKE2b-256 | fa221fedcc67b6b4f6c5d0248e372274ae32ba1ff2988bf71e1d0344b72c0097 |
File details
Details for the file pearl_h-0.1.0-py3-none-any.whl.
File metadata
- Download URL: pearl_h-0.1.0-py3-none-any.whl
- Size: 22.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 630705507a683de7ea058b9b2ccfc302eaa64c15fafa52ddb173bc0bbf3e5701 |
| MD5 | d1beffe028952dc3b6f0eb3f900b901a |
| BLAKE2b-256 | b3905a0672a1f468418090298915691073a4392807470d256737e0ef202c698a |