PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning

PEARL is a framework for refining embeddings through signal extraction and prototype-guided feature augmentation. It aims to improve classification on embedding-based tasks by separating class-discriminative signal from noise and augmenting embeddings with features derived from class prototypes.

Key Features

  • Signal Extraction: Separates discriminative signal from noise in embeddings using deep learning
  • Prototype-Guided Features (PAF): Augments embeddings with rich prototype-based similarity features
  • Easy-to-use API: Simple scikit-learn-like interface
  • Flexible: Works with any embedding (BERT, ResNet, custom embeddings, etc.)
  • Reported Results: Improvements for several downstream classifiers on AG News (see Performance)
  • GPU Accelerated: Built on PyTorch for fast training and inference

Installation

From PyPI (recommended)

pip install pearl-ai

From source

git clone https://github.com/yourusername/pearl.git
cd pearl
pip install -e .

Optional dependencies

For examples and advanced features:

pip install "pearl-ai[examples]"  # Install with example dependencies
pip install "pearl-ai[dev]"       # Install with development tools
pip install "pearl-ai[all]"       # Install everything

Quick Start

import numpy as np
from pearl import PEARLPipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

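# Your embeddings and labels: X is a float array of shape [N, D], y an integer array of shape [N]
# (random data stands in here; replace it with your own embeddings and labels)
rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(1000, 768)).astype(np.float32), rng.integers(0, 10, size=1000)
X_test, y_test = rng.normal(size=(200, 768)).astype(np.float32), rng.integers(0, 10, size=200)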

# Initialize PEARL
pearl = PEARLPipeline(
    n_classes=10,
    device='cuda'  # or 'cpu'
)

# Fit PEARL on training data
pearl.fit(X_train, y_train)

# Transform embeddings
X_train_enhanced = pearl.transform(X_train, mode='enhanced')
X_test_enhanced = pearl.transform(X_test, mode='enhanced')

# Use with any classifier
clf = LogisticRegression()
clf.fit(X_train_enhanced, y_train)
pred = clf.predict(X_test_enhanced)

print(f"F1 Score: {f1_score(y_test, pred, average='macro'):.4f}")

How PEARL Works

PEARL enhances embeddings through two key steps:

1. Signal Extraction

The Signal Extractor learns to separate embeddings into:

  • Signal: Class-discriminative information
  • Noise: Non-discriminative variations

It uses a multi-task learning approach with:

  • Reconstruction loss (preserve information)
  • Centroid alignment loss (align with class centers)
  • Contrastive loss (separate classes)
  • Orthogonality loss (decorrelate signal and noise)
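The four objectives are combined into one weighted training loss. The NumPy sketch below illustrates that combination under stated assumptions and is not PEARL's actual implementation: the exact form of each term (mean squared error for reconstruction, squared distance to the in-batch class centroid, a hinge on distances between class centroids as a simple stand-in for the contrastive term, and the squared signal/noise cross-covariance for orthogonality), the function name `combined_loss`, and the `margin` parameter are all hypothetical. The default weights mirror the values used in the Custom Configuration example (`recon_weight=1.0`, `centroid_weight=2.0`, `contrast_weight=0.5`, `ortho_weight=0.5`).

```python
import numpy as np


def combined_loss(x, x_recon, signal, noise, labels,
                  recon_weight=1.0, centroid_weight=2.0,
                  contrast_weight=0.5, ortho_weight=0.5, margin=1.0):
    """Hypothetical weighted sum of four PEARL-style objectives for one batch.

    x, x_recon: [N, D] inputs and reconstructions; signal: [N, S]; noise: [N, K]; labels: [N].
    """
    # Reconstruction: mean squared error between the input and its reconstruction
    recon = np.mean((x - x_recon) ** 2)

    # Class centroids of the signal vectors within this batch
    classes = np.unique(labels)
    centroids = np.stack([signal[labels == c].mean(axis=0) for c in classes])
    idx = np.searchsorted(classes, labels)

    # Centroid alignment: mean squared distance of each signal vector to its class centroid
    centroid = np.mean(np.sum((signal - centroids[idx]) ** 2, axis=1))

    # Contrastive stand-in (assumed form): hinge penalty when two class centroids are closer than `margin`
    dists = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=-1)
    off_diag = ~np.eye(len(classes), dtype=bool)
    contrast = np.mean(np.maximum(0.0, margin - dists[off_diag])) if len(classes) > 1 else 0.0

    # Orthogonality: squared Frobenius norm of the signal/noise cross-covariance
    s = signal - signal.mean(axis=0)
    n = noise - noise.mean(axis=0)
    ortho = np.sum((s.T @ n / len(x)) ** 2)

    total = (recon_weight * recon + centroid_weight * centroid
             + contrast_weight * contrast + ortho_weight * ortho)
    return total, {"recon": recon, "centroid": centroid,
                   "contrast": contrast, "ortho": ortho}
```

In a real trainer these terms would be computed on framework tensors (for example in PyTorch) so they can be backpropagated; NumPy is used here only to keep the arithmetic visible.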

2. Prototype-Guided Augmentation (PAF)

PAF augments embeddings with rich features based on learned prototypes:

  • Maximum similarity to per-class prototypes
  • Mean similarity to per-class prototypes
  • Similarity to class centroids
  • Decision margin (confidence)
  • Prediction entropy (uncertainty)
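The five feature groups above can be computed from cosine similarities. The sketch below is an assumed illustration, not PEARL's code: the function names, the use of cosine similarity, and the choice to derive the margin and entropy from a softmax (with a hypothetical `temperature`) over the per-class maximum similarities are all assumptions.

```python
import numpy as np


def _l2_normalize(a):
    # Scale vectors to unit length so dot products become cosine similarities
    return a / np.clip(np.linalg.norm(a, axis=-1, keepdims=True), 1e-12, None)


def prototype_features(X, prototypes, centroids, temperature=0.1):
    """Hypothetical PAF-style features.

    X: [N, D] embeddings; prototypes: [C, P, D] (P per class, C >= 2); centroids: [C, D].
    Returns [N, 3*C + 2]: max sims, mean sims, centroid sims, margin, entropy.
    """
    Xn = _l2_normalize(X)
    sims = np.einsum("nd,cpd->ncp", Xn, _l2_normalize(prototypes))  # [N, C, P]
    max_sim = sims.max(axis=2)                       # best-matching prototype per class
    mean_sim = sims.mean(axis=2)                     # average over the class's prototypes
    cent_sim = Xn @ _l2_normalize(centroids).T       # [N, C]

    # Assumed scoring: softmax over the per-class max similarities
    logits = max_sim / temperature
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=1, keepdims=True)

    top2 = np.sort(probs, axis=1)[:, -2:]
    margin = top2[:, 1:] - top2[:, :1]               # top-1 minus top-2 probability
    entropy = -np.sum(probs * np.log(probs + 1e-12), axis=1, keepdims=True)

    return np.hstack([max_sim, mean_sim, cent_sim, margin, entropy])
```

Under these assumptions an augmented embedding would be `np.hstack([X, prototype_features(X, prototypes, centroids)])`, adding 3C + 2 columns (32 for 10 classes).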

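These features summarize how close an embedding is to each class and how confident that assignment is, giving downstream classifiers additional signal beyond the embedding itself.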

Transformation Modes

PEARL supports three transformation modes:

# Mode 1: Raw (no transformation)
X_raw = pearl.transform(X, mode='raw')

# Mode 2: Enhanced (signal extraction only)
X_enhanced = pearl.transform(X, mode='enhanced')

# Mode 3: PAF (enhanced + prototype features) - RECOMMENDED
X_paf = pearl.transform(X, mode='paf')
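Each mode builds on the previous one. The sketch below shows that relationship under an assumption this README does not confirm: that `'paf'` concatenates the enhanced embedding with prototype features computed from it. `transform`, `extract_signal`, and `prototype_features` here are hypothetical stand-ins, not PEARL functions.

```python
import numpy as np


def transform(X, mode, extract_signal, prototype_features):
    """Hypothetical dispatch over the three modes.

    extract_signal: callable mapping [N, D] -> [N, S] (stands in for the trained Signal Extractor).
    prototype_features: callable mapping [N, S] -> [N, F] (stands in for the fitted PAF step).
    """
    if mode == "raw":
        return X                                     # untouched input embeddings
    enhanced = extract_signal(X)
    if mode == "enhanced":
        return enhanced                              # signal representation only
    if mode == "paf":
        return np.hstack([enhanced, prototype_features(enhanced)])  # signal + PAF columns
    raise ValueError(f"unknown mode: {mode!r}")
```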

Advanced Usage

Custom Configuration

from pearl import PEARLPipeline

pearl = PEARLPipeline(
    n_classes=10,
    input_dim=768,              # Auto-detected if None
    signal_dim=256,             # Signal representation dimension
    hidden_dims=(512, 384),     # Hidden layers for encoder
    n_prototypes_per_class=3,   # Prototypes per class
    device='cuda',
    dropout=0.3,
    random_state=42
)

# Customize training hyperparameters
pearl.fit(
    X_train, y_train,
    X_val, y_val,
    lr=1e-3,
    weight_decay=1e-4,
    batch_size=128,
    epochs=100,
    patience=20,            # Early-stopping patience (epochs)
    recon_weight=1.0,       # Reconstruction loss weight
    centroid_weight=2.0,    # Centroid loss weight
    contrast_weight=0.5,    # Contrastive loss weight
    ortho_weight=0.5,       # Orthogonality loss weight
    verbose=True
)
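The `epochs` and `patience` arguments suggest validation-based early stopping. As a minimal sketch of the usual rule (assuming PEARL monitors a validation loss on `X_val`/`y_val`, which this README does not state), training stops once `patience` epochs pass without improvement, and the best epoch is the one whose weights would be kept. `early_stopping` is a hypothetical helper operating on a precomputed list of losses.

```python
def early_stopping(val_losses, patience=20):
    """Hypothetical helper: return (best_epoch, stop_epoch) for per-epoch validation losses (0-indexed)."""
    best_loss = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch      # new best: restart the patience window
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch                 # no improvement for `patience` epochs
    return best_epoch, len(val_losses) - 1           # used the full epoch budget
```

For example, `early_stopping([5.0, 4.0, 3.0, 3.5, 3.6, 3.7], patience=2)` returns `(2, 4)`: the best loss is at epoch 2, and epoch 4 is the second epoch without improvement.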

Save and Load Pipeline

# Save trained pipeline
pearl.save('./my_pearl_model')

# Load pipeline
from pearl import PEARLPipeline
pearl = PEARLPipeline.load('./my_pearl_model', device='cuda')

# Use immediately
X_enhanced = pearl.transform(X_test, mode='enhanced')

Using Individual Components

from pearl import SignalExtractorTrainer, PAFAugmentor, SignalExtractor

# 1. Signal Extraction
model = SignalExtractor(input_dim=768, signal_dim=256, n_classes=10)
trainer = SignalExtractorTrainer(model, device='cuda')
trainer.fit(X_train, y_train, X_val, y_val)
X_enhanced = trainer.transform(X_test)

# 2. PAF Features
paf = PAFAugmentor(n_classes=10, n_prototypes_per_class=3)
paf.fit(X_train, y_train)
X_paf = paf.transform(X_test)  # Augmented embeddings
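This README does not say how `PAFAugmentor` (or `PrototypeFeatures`) chooses its prototypes. One common approach, shown here purely as an assumed illustration, is to run k-means separately on each class and keep `n_prototypes_per_class` centers. `fit_class_prototypes` is hypothetical, and its plain Lloyd's-algorithm loop stands in for whatever the library actually does.

```python
import numpy as np


def fit_class_prototypes(X, y, n_classes, n_prototypes_per_class=3, n_iter=20, seed=0):
    """Hypothetical per-class k-means; returns prototypes of shape [n_classes, P, D].

    Assumes every class in range(n_classes) has at least one sample in y.
    """
    rng = np.random.default_rng(seed)
    protos = np.zeros((n_classes, n_prototypes_per_class, X.shape[1]))
    for c in range(n_classes):
        Xc = X[y == c]
        # Initialise centers from random members of the class (with replacement if the class is small)
        replace = len(Xc) < n_prototypes_per_class
        centers = Xc[rng.choice(len(Xc), n_prototypes_per_class, replace=replace)].astype(float)
        for _ in range(n_iter):
            # Assign each point to its nearest center (squared Euclidean distance)
            d = ((Xc[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
            assign = d.argmin(axis=1)
            # Move each center to the mean of its points; leave empty centers in place
            for k in range(n_prototypes_per_class):
                members = Xc[assign == k]
                if len(members):
                    centers[k] = members.mean(axis=0)
        protos[c] = centers
    return protos
```

Per-class centroids for the centroid-similarity feature can then be taken as `X[y == c].mean(axis=0)` for each class `c`.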

Using with RAG Classifier

PEARL also includes a retrieval-augmented classifier (the name RAG comes from Retrieval-Augmented Generation):

from pearl import RAGClassifierWrapper

rag = RAGClassifierWrapper(
    embed_dim=768,
    n_classes=10,
    k=8,  # Number of neighbors to retrieve
    device='cuda'
)

rag.fit(X_train, y_train, X_val, y_val)
predictions = rag.predict(X_test)
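The README does not describe the model inside `RAGClassifierWrapper`; its `fit` takes validation data and a device, which suggests a trained model rather than a plain lookup. The basic retrieval step behind such classifiers can still be sketched: find the `k` training embeddings most similar to a query and turn their labels into class scores. `knn_vote_proba` below is a hypothetical, non-learned stand-in that uses cosine similarity with similarity-weighted voting.

```python
import numpy as np


def knn_vote_proba(X_train, y_train, X_query, n_classes, k=8):
    """Hypothetical similarity-weighted k-NN class probabilities, shape [n_query, n_classes]."""
    def unit(a):
        return a / np.clip(np.linalg.norm(a, axis=1, keepdims=True), 1e-12, None)

    sims = unit(X_query) @ unit(X_train).T              # cosine similarity to every training item
    k = min(k, len(X_train))
    top = np.argsort(-sims, axis=1)[:, :k]              # indices of the k most similar items
    proba = np.zeros((len(X_query), n_classes))
    for i, neighbours in enumerate(top):
        # Clamp at zero so dissimilar neighbours do not vote negatively
        weights = np.maximum(sims[i, neighbours], 0.0)
        np.add.at(proba[i], y_train[neighbours], weights)
    totals = proba.sum(axis=1, keepdims=True)
    # Fall back to a uniform distribution if every neighbour had non-positive similarity
    return np.where(totals > 0, proba / np.where(totals > 0, totals, 1.0), 1.0 / n_classes)
```

Predicted labels would then be `knn_vote_proba(...).argmax(axis=1)`.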

Examples

Text Classification with BERT

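from transformers import AutoTokenizer, AutoModel
from pearl import PEARLPipeline
import numpy as np
import torch

# Extract BERT embeddings (the [CLS] token of the last hidden layer)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')
model.eval()  # make sure dropout is disabled

def get_embeddings(texts, batch_size=32):
    # Encode in batches so large datasets do not exhaust memory
    chunks = []
    for i in range(0, len(texts), batch_size):
        inputs = tokenizer(texts[i:i + batch_size], padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            outputs = model(**inputs)
        chunks.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())
    return np.concatenate(chunks, axis=0)

# Get embeddings (train_texts/test_texts: lists of strings; y_train/y_test: label arrays)
X_train = get_embeddings(train_texts)
X_test = get_embeddings(test_texts)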

# Apply PEARL (num_classes is the number of distinct labels, e.g. len(set(y_train)))
pearl = PEARLPipeline(n_classes=num_classes, device='cuda')
pearl.fit(X_train, y_train)

X_train_enhanced = pearl.transform(X_train, mode='paf')
X_test_enhanced = pearl.transform(X_test, mode='paf')

# Train classifier
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression()
clf.fit(X_train_enhanced, y_train)
accuracy = clf.score(X_test_enhanced, y_test)

See the examples/ directory of the repository for complete working examples.

Performance

Reported F1 scores on AG News, comparing raw embeddings with PEARL-enhanced embeddings fed to the same classifier:

Dataset   Classifier   Raw F1   PEARL F1   Improvement (F1 points)
AG News   Logistic     0.8542   0.8876     +3.34
AG News   SVM          0.8621   0.8912     +2.91
AG News   MLP          0.8698   0.9045     +3.47
AG News   RAG          0.8823   0.9156     +3.33

Across these four classifiers, PEARL adds between 2.91 and 3.47 F1 points over the raw embeddings.
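The improvement figures are absolute differences in F1 multiplied by 100 (percentage points), not relative gains. The short computation below, using only the numbers from the table, makes the distinction explicit; relative gains come out somewhat larger because each baseline is below 1.0.

```python
# (classifier, raw F1, PEARL F1) on AG News, copied from the table above
results = [
    ("Logistic", 0.8542, 0.8876),
    ("SVM", 0.8621, 0.8912),
    ("MLP", 0.8698, 0.9045),
    ("RAG", 0.8823, 0.9156),
]

for name, raw, pearl in results:
    points = (pearl - raw) * 100            # absolute gain in F1 points
    relative = (pearl - raw) / raw * 100    # gain relative to the raw baseline, in %
    print(f"{name:<8} +{points:.2f} pts  (+{relative:.2f}% relative)")
```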

API Reference

PEARLPipeline

Main interface for PEARL.

Methods:

  • fit(X_train, y_train, X_val, y_val, **kwargs): Train the pipeline
  • transform(X, mode='paf'): Transform embeddings
  • fit_transform(X, y, **kwargs): Fit and transform in one step
  • save(path): Save pipeline to disk
  • load(path, device): Load pipeline from disk (class method)

SignalExtractor

Neural network for signal extraction.

Methods:

  • forward(x): Forward pass returning all outputs
  • get_enhanced_embedding(x): Extract enhanced embedding

PrototypeFeatures

Prototype-based feature generator.

Methods:

  • fit(embeddings, labels): Learn prototypes from training data
  • transform(embeddings): Generate prototype features
  • get_augmented(embeddings): Get embeddings + features

RAGClassifierWrapper

Retrieval-augmented classifier.

Methods:

  • fit(X_train, y_train, X_val, y_val, **kwargs): Train the model
  • predict(X): Predict class labels
  • predict_proba(X): Predict class probabilities

Requirements

  • Python >= 3.8
  • PyTorch >= 1.12.0
  • NumPy >= 1.20.0
  • scikit-learn >= 1.0.0
  • pandas >= 1.3.0
  • openpyxl >= 3.0.0

Citation

If you use PEARL in your research, please cite:

@software{pearl2024,
  title={PEARL: Prototype-guided Embedding Refinement via Adaptive Representation Learning},
  author={PEARL Contributors},
  year={2024},
  url={https://github.com/yourusername/pearl}
}

Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.

Development Setup

git clone https://github.com/yourusername/pearl.git
cd pearl
pip install -e ".[dev]"

# Run tests
pytest tests/

# Format code
black pearl/

# Type checking
mypy pearl/

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

PEARL was developed to address the challenge of enhancing learned embeddings for downstream classification tasks. It builds on ideas from:

  • Signal processing and denoising
  • Prototype-based learning
  • Multi-task learning
  • Retrieval-augmented generation

Roadmap

  • Support for additional embedding types (images, audio)
  • Pre-trained models for common datasets
  • Integration with popular frameworks (HuggingFace, PyTorch Lightning)
  • Online/incremental learning support
  • Multi-label classification support
  • Comprehensive benchmarking suite

Made with ❤️ by the PEARL team

Download files

File details

Details for the file pearl_h-0.1.0.tar.gz.

File metadata

  • Download URL: pearl_h-0.1.0.tar.gz
  • Size: 22.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.7

File hashes

Hashes for pearl_h-0.1.0.tar.gz
Algorithm Hash digest
SHA256 c24564ee610fdb339de315a806897318186eaaf88a444722b42a16607585dd82
MD5 bdaeaa4d58db419d119151224adfdd33
BLAKE2b-256 fa221fedcc67b6b4f6c5d0248e372274ae32ba1ff2988bf71e1d0344b72c0097

File details

Details for the file pearl_h-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: pearl_h-0.1.0-py3-none-any.whl
  • Size: 22.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.7

File hashes

Hashes for pearl_h-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 630705507a683de7ea058b9b2ccfc302eaa64c15fafa52ddb173bc0bbf3e5701
MD5 d1beffe028952dc3b6f0eb3f900b901a
BLAKE2b-256 b3905a0672a1f468418090298915691073a4392807470d256737e0ef202c698a
