Joint-Embedding Predictive Architecture for Self-Supervised Learning
JEPA Framework
A powerful self-supervised learning framework for Joint-Embedding Predictive Architecture (JEPA)
Installation • Quick Start • Documentation • Examples • Contributing
🚀 Overview
JEPA (Joint-Embedding Predictive Architecture) is a cutting-edge self-supervised learning framework that learns rich representations by predicting parts of the input from other parts. This implementation provides a flexible, production-ready framework for training JEPA models across multiple modalities.
Key Features
🔧 Modular Design
- Flexible encoder-predictor architecture
- Support for any PyTorch model as encoder/predictor
- Easy to extend and customize for your specific needs
🌍 Multi-Modal Support
- Computer Vision: Images, videos, medical imaging
- Natural Language Processing: Text, documents, code
- Time Series: Sequential data, forecasting, anomaly detection
- Audio: Speech, music, environmental sounds
- Multimodal: Vision-language, audio-visual learning
⚡ High Performance
- Mixed precision training (FP16/BF16)
- Native DistributedDataParallel (DDP) support
- Memory-efficient implementations
- Optimized for both research and production
📊 Comprehensive Logging
- Weights & Biases integration
- TensorBoard support
- Console logging with rich formatting
- Multi-backend logging system
🎛️ Production Ready
- CLI interface for easy deployment
- Flexible YAML configuration system
- Comprehensive testing suite
- Docker support and containerization
- Type hints throughout
🏗️ Architecture
JEPA follows a simple yet powerful architecture:
graph LR
A[Input Data] --> B[Context/Target Split]
B --> C[Encoder]
C --> D[Joint Embedding Space]
D --> E[Predictor]
E --> F[Target Prediction]
F --> G[Loss Computation]
The model learns by:
- Splitting the input into context and target regions
- Encoding the context and the target separately
- Predicting the target embeddings from the context embeddings
- Minimizing the distance between predicted and actual target embeddings, so the learned representations capture meaningful relationships (see the sketch below)
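A minimal, illustrative sketch of a single training step, assuming a generic encoder/predictor interface and a hypothetical mask_fn that produces the context/target split (the bundled trainer handles these details internally):

import torch
import torch.nn.functional as F

def jepa_step(encoder, predictor, batch, mask_fn):
    # mask_fn is a hypothetical helper that splits a batch into context/target views
    context, target = mask_fn(batch)
    z_context = encoder(context)
    with torch.no_grad():  # the target branch is typically kept out of the gradient path
        z_target = encoder(target)
    z_pred = predictor(z_context)
    # Loss: distance between predicted and actual target embeddings
    return F.mse_loss(z_pred, z_target)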
📦 Installation
From PyPI (Recommended)
pip install jepa
From Source
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e .
Development Installation
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e ".[dev,docs]"
Docker
docker pull dipsivenkatesh/jepa:latest
docker run -it dipsivenkatesh/jepa:latest
🚀 Quick Start
Python API
import torch
from torch.utils.data import DataLoader, TensorDataset
from jepa.models import JEPA
from jepa.models.encoder import Encoder
from jepa.models.predictor import Predictor
from jepa.trainer import create_trainer
# Toy dataset of (state_t, state_t1) pairs
dataset = TensorDataset(torch.randn(256, 16, 128), torch.randn(256, 16, 128))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
# Build model components
encoder = Encoder(hidden_dim=128)
predictor = Predictor(hidden_dim=128)
model = JEPA(encoder=encoder, predictor=predictor)
# Trainer with sensible defaults
trainer = create_trainer(model, learning_rate=3e-4, device="auto")
# Train for a couple of epochs
trainer.train(train_loader, num_epochs=2)
# Optional: stream metrics to Weights & Biases
trainer_wandb = create_trainer(
    model,
    learning_rate=3e-4,
    device="auto",
    logger="wandb",
    logger_project="jepa-experiments",
    logger_run_name="quickstart-run",
)
# Persist weights for downstream inference
model.save_pretrained("artifacts/jepa-small")
# Reload using the same model class
reloaded = JEPA.from_pretrained("artifacts/jepa-small", encoder=encoder, predictor=predictor)
Distributed Training (DDP)
Launch multi-GPU jobs with PyTorch's launcher:
torchrun --nproc_per_node=4 scripts/train.py --config config/default_config.yaml
Inside your training script, enable DDP when you create the trainer:
import os

trainer = create_trainer(
    model,
    distributed=True,
    world_size=int(os.environ["WORLD_SIZE"]),
    local_rank=int(os.environ.get("LOCAL_RANK", 0)),
)
The trainer wraps the model in DistributedDataParallel, synchronizes losses, and restricts logging/checkpointing to rank zero automatically.
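For intuition, the behavior described above corresponds roughly to the following manual setup (illustrative only; the exact internals of create_trainer may differ):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

# Wrap the model so gradients are all-reduced across ranks
ddp_model = DDP(model.cuda(), device_ids=[local_rank])

# Only rank 0 writes logs and checkpoints
if dist.get_rank() == 0:
    print("logging and checkpointing happen on rank 0")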
Action-Conditioned Variant
Use JEPAAction when actions influence the next state. Provide a state encoder, an action encoder, and a predictor that consumes the concatenated [z_t, a_t] embedding.
from jepa import JEPAAction
import torch.nn as nn
state_dim = 512
action_dim = 64
# Example encoders (replace with your own)
state_encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, state_dim),
)
action_encoder = nn.Sequential(
    nn.Linear(10, 128), nn.ReLU(),
    nn.Linear(128, action_dim),
)

# Predictor takes [state_dim + action_dim] → state_dim
predictor = nn.Sequential(
    nn.Linear(state_dim + action_dim, 512), nn.ReLU(),
    nn.Linear(512, state_dim),
)
model = JEPAAction(state_encoder, action_encoder, predictor)
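For clarity, here is what the action-conditioned prediction looks like when the components above are applied by hand; the batch shapes are illustrative, and JEPAAction performs the equivalent concatenate-then-predict step described above:

import torch

state_t = torch.randn(8, 1, 28, 28)   # flattened to 784 features by the state encoder
action_t = torch.randn(8, 10)         # 10-dimensional action vectors

z_t = state_encoder(state_t)                           # (8, state_dim)
a_t = action_encoder(action_t)                         # (8, action_dim)
z_t1_pred = predictor(torch.cat([z_t, a_t], dim=-1))   # predicted next-state embedding, (8, state_dim)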
Command Line Interface
# Train a model
jepa-train --config config/default_config.yaml
# Train with custom parameters
jepa-train --config config/vision_config.yaml \
--batch-size 64 \
--learning-rate 0.001 \
--num-epochs 100
# Evaluate a trained model
jepa-evaluate --config config/default_config.yaml \
--checkpoint checkpoints/best_model.pth
# Generate a configuration template
jepa-train --generate-config my_config.yaml
# Get help
jepa-train --help
Configuration
JEPA uses YAML configuration files for easy experiment management:
# config/my_experiment.yaml
model:
  encoder_type: "transformer"
  encoder_dim: 768
  predictor_type: "mlp"
  predictor_hidden_dim: 2048

training:
  batch_size: 32
  learning_rate: 0.0001
  num_epochs: 100
  warmup_epochs: 10

data:
  train_data_path: "data/train"
  val_data_path: "data/val"
  sequence_length: 16

logging:
  wandb:
    enabled: true
    project: "jepa-experiments"
  tensorboard:
    enabled: true
    log_dir: "./tb_logs"
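Once written, the file can be loaded and handed to the trainer. A minimal sketch, assuming JEPATrainer can build its components from the config alone:

from jepa import JEPATrainer, load_config

config = load_config("config/my_experiment.yaml")
trainer = JEPATrainer(config=config)
trainer.train()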
🎯 Use Cases
Computer Vision
- Image Classification: Pre-train backbones for downstream tasks
- Object Detection: Learn robust visual representations
- Medical Imaging: Analyze medical scans and imagery
- Satellite Imagery: Process large-scale geographic data
Natural Language Processing
- Language Models: Pre-train transformer architectures
- Document Understanding: Learn document-level representations
- Code Analysis: Understand code structure and semantics
- Cross-lingual Learning: Build multilingual representations
Time Series Analysis
- Forecasting: Pre-train models for prediction tasks
- Anomaly Detection: Learn normal patterns in sequential data
- Financial Modeling: Analyze market trends and patterns
- IoT Sensors: Process sensor data streams
Multimodal Learning
- Vision-Language: Combine images and text understanding
- Audio-Visual: Learn from synchronized audio and video
- Cross-Modal Retrieval: Search across different modalities
- Embodied AI: Integrate multiple sensor modalities
📚 Examples
Vision Example
from jepa import JEPADataset, JEPATrainer, load_config
import torch.nn as nn
# Custom vision encoder
class VisionEncoder(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=512):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_dim, 64, 7, 2, 3),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, hidden_dim),
        )

    def forward(self, x):
        return self.conv_layers(x)
# Load config and customize
config = load_config("config/vision_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=VisionEncoder)
trainer.train()
NLP Example
import torch.nn as nn
from transformers import AutoModel
from jepa import JEPATrainer, load_config
# Use pre-trained transformer as encoder
class TransformerEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state.mean(dim=1)  # Pool over sequence
config = load_config("config/nlp_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=TransformerEncoder)
trainer.train()
Time Series Example
from jepa import create_dataset, load_config, JEPATrainer
# Create time series dataset
dataset = create_dataset(
    data_path="data/timeseries.csv",
    sequence_length=50,
    prediction_length=10,
    features=['sensor1', 'sensor2', 'sensor3'],
)
config = load_config("config/timeseries_config.yaml")
trainer = JEPATrainer(config=config, train_dataset=dataset)
trainer.train()
📖 Documentation
- Full Documentation - Complete API reference and guides
- Installation Guide - Detailed installation instructions
- Configuration Guide - How to configure your experiments
- Training Guide - Training best practices
- API Reference - Complete API documentation
🔧 Development
Setting up Development Environment
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install in development mode
pip install -e ".[dev,docs]"
# Install pre-commit hooks
pre-commit install
Running Tests
# Run all tests
pytest
# Run with coverage
pytest --cov=jepa --cov-report=html
# Run specific test
pytest tests/test_model.py::test_jepa_forward
Code Quality
# Format code
black jepa/
isort jepa/
# Type checking
mypy jepa/
# Linting
flake8 jepa/
🤝 Contributing
We welcome contributions! Please see our Contributing Guide for details.
Ways to Contribute
- 🐛 Bug Reports: Submit detailed bug reports
- ✨ Feature Requests: Suggest new features or improvements
- 📖 Documentation: Improve documentation and examples
- 🔧 Code: Submit pull requests with bug fixes or new features
- 🎯 Use Cases: Share your JEPA applications and results
Development Workflow
- Fork the repository
- Create a feature branch: git checkout -b feature-name
- Make your changes and add tests
- Ensure all tests pass: pytest
- Submit a pull request
📄 Citation
If you use JEPA in your research, please cite:
@software{jepa2025,
  title = {JEPA: Joint-Embedding Predictive Architecture Framework},
  author = {Venkatesh, Dilip},
  year = {2025},
  url = {https://github.com/dipsivenkatesh/jepa},
  version = {0.1.0}
}
📝 License
This project is licensed under the MIT License - see the LICENSE file for details.
🙏 Acknowledgments
- Inspired by the original JEPA paper and Meta's research
- Built with PyTorch, Transformers, and other amazing open-source libraries
- Thanks to all contributors and users of the framework
📞 Support
- GitHub Issues: Report bugs or request features
- Documentation: Read the full documentation
- Discussions: Join community discussions
Steps to push the latest version (maintainers):
rm -rf dist build *.egg-info
python -m build
twine upload dist/*
⭐ Star this repo | 📖 Read the docs | 🐛 Report issues
Built with ❤️ by Dilip Venkatesh