Joint-Embedding Predictive Architecture for Self-Supervised Learning

JEPA Framework

Python 3.8+ | License: MIT

A powerful self-supervised learning framework for Joint-Embedding Predictive Architecture (JEPA)

🚀 Overview

JEPA (Joint-Embedding Predictive Architecture) is a self-supervised learning approach that learns representations by predicting the embeddings of some parts of the input from the embeddings of other parts, rather than reconstructing the raw input. This implementation provides a flexible, production-ready framework for training JEPA models across multiple modalities.

Key Features

🔧 Modular Design

  • Flexible encoder-predictor architecture
  • Support for any PyTorch model as encoder/predictor
  • Easy to extend and customize for your specific needs

🌍 Multi-Modal Support

  • Computer Vision: Images, videos, medical imaging
  • Natural Language Processing: Text, documents, code
  • Time Series: Sequential data, forecasting, anomaly detection
  • Audio: Speech, music, environmental sounds
  • Multimodal: Vision-language, audio-visual learning

High Performance

  • Mixed precision training (FP16/BF16)
  • Native DistributedDataParallel (DDP) support
  • Memory-efficient implementations
  • Optimized for both research and production

📊 Comprehensive Logging

  • Weights & Biases integration
  • TensorBoard support
  • Console logging with rich formatting
  • Multi-backend logging system

🎛️ Production Ready

  • CLI interface for easy deployment
  • Flexible YAML configuration system
  • Comprehensive testing suite
  • Docker support and containerization
  • Type hints throughout

🏗️ Architecture

JEPA follows a simple yet powerful architecture:

graph LR
    A[Input Data] --> B[Context/Target Split]
    B --> C[Encoder]
    C --> D[Joint Embedding Space]
    D --> E[Predictor]
    E --> F[Target Prediction]
    F --> G[Loss Computation]

The model learns by:

  1. Splitting input into context and target regions
  2. Encoding both context and target separately
  3. Predicting target embeddings from context embeddings
  4. Learning representations that capture meaningful relationships
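
The four steps above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: the linear encoder and predictor, the tensor shapes, the half-and-half context/target split, and the choice of an MSE loss with no gradient through the target branch are all assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for an encoder and predictor (not the jepa classes).
encoder = nn.Linear(32, 64)
predictor = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# 1. Split a batch of sequences into context and target regions.
x = torch.randn(8, 16, 32)          # (batch, sequence, features)
context, target = x[:, :8], x[:, 8:]

# 2. Encode context and target separately with the same encoder.
z_context = encoder(context)
with torch.no_grad():               # assumption: no gradient through the target branch
    z_target = encoder(target)

# 3. Predict target embeddings from context embeddings.
z_pred = predictor(z_context)

# 4. Compare in embedding space; the loss trains encoder and predictor.
loss = F.mse_loss(z_pred, z_target)
loss.backward()
```

The loss is computed between embeddings, never between raw inputs, which is the property that distinguishes this setup from reconstruction-based objectives.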

📦 Installation

From PyPI (Recommended)

pip install jepa

From Source

git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e .

Development Installation

git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e ".[dev,docs]"

Docker

docker pull dipsivenkatesh/jepa:latest
docker run -it dipsivenkatesh/jepa:latest

🚀 Quick Start

Python API

import torch
from torch.utils.data import DataLoader, TensorDataset

from jepa.models import JEPA
from jepa.models.encoder import Encoder
from jepa.models.predictor import Predictor
from jepa.trainer import create_trainer

# Toy dataset of (state_t, state_t1) pairs
dataset = TensorDataset(torch.randn(256, 16, 128), torch.randn(256, 16, 128))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Build model components
encoder = Encoder(hidden_dim=128)
predictor = Predictor(hidden_dim=128)
model = JEPA(encoder=encoder, predictor=predictor)

# Trainer with sensible defaults
trainer = create_trainer(model, learning_rate=3e-4, device="auto")

# Train for a couple of epochs
trainer.train(train_loader, num_epochs=2)

# Optional: stream metrics to Weights & Biases
trainer_wandb = create_trainer(
    model,
    learning_rate=3e-4,
    device="auto",
    logger="wandb",
    logger_project="jepa-experiments",
    logger_run_name="quickstart-run",
)

# Persist weights for downstream inference
model.save_pretrained("artifacts/jepa-small")

# Reload using the same model class
reloaded = JEPA.from_pretrained("artifacts/jepa-small", encoder=encoder, predictor=predictor)

Distributed Training (DDP)

Launch multi-GPU jobs with PyTorch's launcher:

torchrun --nproc_per_node=4 scripts/train.py --config config/default_config.yaml

Inside your training script, enable DDP when you create the trainer:

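import os

from jepa.trainer import create_trainer

# WORLD_SIZE and LOCAL_RANK are set by torchrun for each process
trainer = create_trainer(
    model,
    distributed=True,
    world_size=int(os.environ["WORLD_SIZE"]),
    local_rank=int(os.environ.get("LOCAL_RANK", 0)),
)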

The trainer wraps the model in DistributedDataParallel, synchronizes losses, and restricts logging/checkpointing to rank zero automatically.
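
For reference, torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE to every process it launches. The sketch below shows one way a script could read them and decide which process logs and checkpoints. The names `DistEnv`, `read_dist_env`, and `is_rank_zero` are hypothetical helpers for illustration and are not part of the jepa package.

```python
import os
from typing import Mapping, NamedTuple


class DistEnv(NamedTuple):
    rank: int
    local_rank: int
    world_size: int


def read_dist_env(env: Mapping[str, str] = os.environ) -> DistEnv:
    """Read the variables torchrun exports, defaulting to a single process."""
    return DistEnv(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


def is_rank_zero(dist: DistEnv) -> bool:
    """Only the global rank-0 process should log or write checkpoints."""
    return dist.rank == 0


# Simulate the third of four processes launched by torchrun.
dist = read_dist_env({"RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "4"})
print(dist, is_rank_zero(dist))
```

Falling back to a world size of 1 lets the same script run unchanged without torchrun.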

Action-Conditioned Variant

Use JEPAAction when actions influence the next state. Provide a state encoder, an action encoder, and a predictor that consumes the concatenated [z_t, a_t] embedding.

from jepa import JEPAAction
import torch.nn as nn

state_dim = 512
action_dim = 64

# Example encoders (replace with your own)
state_encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, state_dim),
)
action_encoder = nn.Sequential(
    nn.Linear(10, 128), nn.ReLU(),
    nn.Linear(128, action_dim),
)

# Predictor takes [state_dim + action_dim] → state_dim
predictor = nn.Sequential(
    nn.Linear(state_dim + action_dim, 512), nn.ReLU(),
    nn.Linear(512, state_dim),
)

model = JEPAAction(state_encoder, action_encoder, predictor)
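
The data flow described above can be checked with plain PyTorch, independently of JEPAAction. The batch size, the 1x28x28 state input (implied by the 784-feature linear layer), and the 10-dimensional action vector are assumptions, and the explicit `torch.cat` call only stands in for whatever JEPAAction does internally.

```python
import torch
import torch.nn as nn

state_dim, action_dim = 512, 64

state_encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, state_dim))
action_encoder = nn.Sequential(
    nn.Linear(10, 128), nn.ReLU(),
    nn.Linear(128, action_dim),
)
predictor = nn.Sequential(
    nn.Linear(state_dim + action_dim, 512), nn.ReLU(),
    nn.Linear(512, state_dim),
)

# Hypothetical batch: 4 states of shape 1x28x28 and 4 action vectors of size 10.
state_t = torch.randn(4, 1, 28, 28)
action_t = torch.randn(4, 10)

z_t = state_encoder(state_t)                             # (4, 512)
a_t = action_encoder(action_t)                           # (4, 64)
z_next_pred = predictor(torch.cat([z_t, a_t], dim=-1))   # (4, 512)
```

The predictor's input width must equal `state_dim + action_dim`; a mismatch there is the most likely shape error when swapping in custom encoders.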

Command Line Interface

# Train a model
jepa-train --config config/default_config.yaml

# Train with custom parameters
jepa-train --config config/vision_config.yaml \
           --batch-size 64 \
           --learning-rate 0.001 \
           --num-epochs 100

# Evaluate a trained model
jepa-evaluate --config config/default_config.yaml \
              --checkpoint checkpoints/best_model.pth

# Generate a configuration template
jepa-train --generate-config my_config.yaml

# Get help
jepa-train --help

Configuration

JEPA uses YAML configuration files for easy experiment management:

# config/my_experiment.yaml
model:
  encoder_type: "transformer"
  encoder_dim: 768
  predictor_type: "mlp"
  predictor_hidden_dim: 2048

training:
  batch_size: 32
  learning_rate: 0.0001
  num_epochs: 100
  warmup_epochs: 10

data:
  train_data_path: "data/train"
  val_data_path: "data/val"
  sequence_length: 16

logging:
  wandb:
    enabled: true
    project: "jepa-experiments"
  tensorboard:
    enabled: true
    log_dir: "./tb_logs"
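
The CLI examples show flags such as --batch-size and --learning-rate alongside a config file, which suggests that flags take precedence over YAML values. One way such a merge could work is sketched below on plain dictionaries. `apply_overrides` and the dotted-key convention are hypothetical and are not taken from the jepa code base.

```python
import copy
from typing import Any, Dict


def apply_overrides(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of config with dotted-key overrides applied."""
    merged = copy.deepcopy(config)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = merged
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return merged


# Values mirror the training section of the YAML example.
config = {"training": {"batch_size": 32, "learning_rate": 0.0001, "num_epochs": 100}}
merged = apply_overrides(
    config, {"training.batch_size": 64, "training.learning_rate": 0.001}
)
print(merged["training"])
```

Working on a deep copy keeps the loaded file contents intact, so the original and effective configurations can both be logged.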

🎯 Use Cases

Computer Vision

  • Image Classification: Pre-train backbones for downstream tasks
  • Object Detection: Learn robust visual representations
  • Medical Imaging: Analyze medical scans and imagery
  • Satellite Imagery: Process large-scale geographic data

Natural Language Processing

  • Language Models: Pre-train transformer architectures
  • Document Understanding: Learn document-level representations
  • Code Analysis: Understand code structure and semantics
  • Cross-lingual Learning: Build multilingual representations

Time Series Analysis

  • Forecasting: Pre-train models for prediction tasks
  • Anomaly Detection: Learn normal patterns in sequential data
  • Financial Modeling: Analyze market trends and patterns
  • IoT Sensors: Process sensor data streams

Multimodal Learning

  • Vision-Language: Combine images and text understanding
  • Audio-Visual: Learn from synchronized audio and video
  • Cross-Modal Retrieval: Search across different modalities
  • Embodied AI: Integrate multiple sensor modalities

📚 Examples

Vision Example

from jepa import JEPADataset, JEPATrainer, load_config
import torch.nn as nn

# Custom vision encoder
class VisionEncoder(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=512):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_dim, 64, 7, 2, 3),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, hidden_dim)
        )
    
    def forward(self, x):
        return self.conv_layers(x)

# Load config and customize
config = load_config("config/vision_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=VisionEncoder)
trainer.train()

NLP Example

import torch.nn as nn
from transformers import AutoModel

from jepa import JEPA, JEPATrainer, load_config

# Use pre-trained transformer as encoder
class TransformerEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
    
    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state.mean(dim=1)  # Pool over sequence

config = load_config("config/nlp_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=TransformerEncoder)
trainer.train()
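
Note that `mean(dim=1)` in the encoder above averages over every token position, including padding. When batches contain padded sequences, a mask-aware mean is a common alternative. The sketch below uses random tensors in place of a real transformer output, and `masked_mean` is a hypothetical helper rather than part of jepa or Transformers.

```python
import torch


def masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over real tokens only (mask == 1)."""
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)   # (batch, seq, 1)
    summed = (hidden * mask).sum(dim=1)                    # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1.0)                # avoid division by zero
    return summed / counts


hidden = torch.randn(2, 4, 8)                       # stand-in for last_hidden_state
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
pooled = masked_mean(hidden, attention_mask)
```

For the second sequence, only the first two positions contribute to the pooled vector.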

Time Series Example

from jepa import create_dataset, JEPATrainer, load_config

# Create time series dataset
dataset = create_dataset(
    data_path="data/timeseries.csv",
    sequence_length=50,
    prediction_length=10,
    features=['sensor1', 'sensor2', 'sensor3']
)

config = load_config("config/timeseries_config.yaml")
trainer = JEPATrainer(config=config, train_dataset=dataset)
trainer.train()
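
The `sequence_length` and `prediction_length` arguments suggest a sliding-window split into a context window and the values that follow it. How `create_dataset` actually builds its samples is not shown here, so the `sliding_windows` function below is only an illustrative assumption.

```python
from typing import List, Sequence, Tuple


def sliding_windows(
    series: Sequence[float], sequence_length: int, prediction_length: int
) -> List[Tuple[Sequence[float], Sequence[float]]]:
    """Pair each context window with the values that immediately follow it."""
    total = sequence_length + prediction_length
    return [
        (series[i : i + sequence_length], series[i + sequence_length : i + total])
        for i in range(len(series) - total + 1)
    ]


series = list(range(100))                 # stand-in for one sensor column
pairs = sliding_windows(series, sequence_length=50, prediction_length=10)
context, target = pairs[0]
print(len(pairs), len(context), len(target))   # 41 50 10
```

A series of 100 points yields 41 pairs because each sample needs 60 consecutive values.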

🔧 Development

Setting up Development Environment

git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install in development mode
pip install -e ".[dev,docs]"

# Install pre-commit hooks
pre-commit install

Running Tests

# Run all tests
pytest

# Run with coverage
pytest --cov=jepa --cov-report=html

# Run specific test
pytest tests/test_model.py::test_jepa_forward

Code Quality

# Format code
black jepa/
isort jepa/

# Type checking
mypy jepa/

# Linting
flake8 jepa/

🤝 Contributing

We welcome contributions! Please see our Contributing Guide for details.

Ways to Contribute

  • 🐛 Bug Reports: Submit detailed bug reports
  • Feature Requests: Suggest new features or improvements
  • 📖 Documentation: Improve documentation and examples
  • 🔧 Code: Submit pull requests with bug fixes or new features
  • 🎯 Use Cases: Share your JEPA applications and results

Development Workflow

  1. Fork the repository
  2. Create a feature branch: git checkout -b feature-name
  3. Make your changes and add tests
  4. Ensure all tests pass: pytest
  5. Submit a pull request

📄 Citation

If you use JEPA in your research, please cite:

@software{jepa2025,
  title = {JEPA: Joint-Embedding Predictive Architecture Framework},
  author = {Venkatesh, Dilip},
  year = {2025},
  url = {https://github.com/dipsivenkatesh/jepa},
  version = {0.1.0}
}

📝 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • Inspired by the original JEPA paper and Meta's research
  • Built with PyTorch, Transformers, and other amazing open-source libraries
  • Thanks to all contributors and users of the framework

📞 Support

Steps to push the latest version:

rm -rf dist build *.egg-info
python -m build
twine upload dist/*

