SSM-As-VLM-Bridge: State Space Models as Vision-Language Bridge Layers

License: MIT | Python 3.10+ | PyTorch

📖 Abstract

This repository presents an exploration into leveraging State Space Models (SSMs) as bridge/adapter layers for Vision-Language Models (VLMs). The project introduces a novel architecture that uses SSMs to facilitate cross-modal understanding between visual and textual representations, potentially offering more efficient and interpretable vision-language fusion compared to traditional attention-based approaches.

๐Ÿ—๏ธ Model Architecture

Overview

The SSM-As-VLM-Bridge architecture consists of three main components:

  1. Vision Encoder: Enhanced Vision Transformer (ViT) for image feature extraction
  2. SSM Bridge: Multi-layer State Space Model for cross-modal fusion
  3. Language Model: Transformer-based decoder for text generation

Detailed Architecture

┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│   Vision Input  │    │   Text Input    │    │   Generated     │
│   (224×224×3)   │    │   (Tokens)      │    │   Text Output   │
└────────┬────────┘    └────────┬────────┘    └────────┬────────┘
         │                      │                      │
         ▼                      ▼                      ▲
┌─────────────────┐    ┌─────────────────┐             │
│  Vision Encoder │    │  Token Embedding│             │
│  (Enhanced ViT) │    │  + Positional   │             │
└────────┬────────┘    └────────┬────────┘             │
         │                      │                      │
         ▼                      │                      │
┌─────────────────┐             │                      │
│   SSM Bridge    │             │                      │
│  (Multi-layer   │             │                      │
│   State Space   │             │                      │
│   Model)        │             │                      │
└────────┬────────┘             │                      │
         │                      │                      │
         └──────────┬───────────┘                      │
                    │                                  │
                    ▼                                  │
           ┌─────────────────┐                         │
           │ Cross-Modal     │                         │
           │ Attention       │                         │
           └────────┬────────┘                         │
                    │                                  │
                    ▼                                  │
           ┌─────────────────┐                         │
           │ Language        │                         │
           │ Transformer     │                         │
           │ (Decoder)       │                         │
           └────────┬────────┘                         │
                    │                                  │
                    └──────────────────────────────────┘

Key Components

1. Enhanced Vision Transformer (ViT)

  • Patch Size: 16×16 pixels
  • Embedding Dimension: 768
  • Number of Layers: 12
  • Number of Heads: 12
  • Input Resolution: 224×224×3

2. SSM Bridge Layer

  • State Dimension: 64
  • Hidden Dimension: 256
  • Number of Layers: 4
  • Dropout: 0.1
  • Activation: Enhanced Swish with numerical stability
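At its core, the bridge applies the discrete state space recurrence h_t = A·h_{t-1} + B·x_t, y_t = C·h_t over the sequence of visual tokens. A minimal sketch using the dimensions listed above (the repository's implementation adds layers, gating, and stability tricks; `ssm_scan` and its random parameters are purely illustrative):

```python
import torch

def ssm_scan(x, A, B, C):
    """Minimal discrete SSM recurrence: h_t = A h_{t-1} + B x_t, y_t = C h_t.

    x: (batch, seq_len, input_dim)
    A: (state_dim, state_dim), B: (state_dim, input_dim), C: (hidden_dim, state_dim)
    """
    batch, seq_len, _ = x.shape
    h = torch.zeros(batch, A.shape[0])          # initial state h_0 = 0
    outputs = []
    for t in range(seq_len):
        h = h @ A.T + x[:, t] @ B.T             # state update
        outputs.append(h @ C.T)                 # readout
    return torch.stack(outputs, dim=1)          # (batch, seq_len, hidden_dim)

x = torch.randn(2, 8, 768)                      # e.g. ViT patch features
A = torch.randn(64, 64) * 0.1                   # state_dim = 64, scaled down for stability
B = torch.randn(64, 768) * 0.01
C = torch.randn(256, 64)                        # hidden_dim = 256
y = ssm_scan(x, A, B, C)
print(y.shape)  # torch.Size([2, 8, 256])
```

Unlike attention, the per-step cost here is independent of sequence length, which is the efficiency argument for using an SSM as the bridge.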

3. Cross-Modal Attention

  • Number of Layers: 2
  • Number of Heads: 8
  • Multi-Query Attention: Efficient attention with shared key-value heads

4. Language Model

  • Vocabulary Size: 32,000
  • Embedding Dimension: 768
  • Number of Layers: 12
  • Number of Heads: 12
  • KV Heads: 1 (Multi-Query)
  • Max Sequence Length: 2,048
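Multi-Query Attention (1 KV head above) gives each head its own queries but shares a single key/value projection across all heads, shrinking the KV cache roughly `num_heads`-fold. A hedged sketch of the mechanism (`multi_query_attention` is illustrative, not the repo's API):

```python
import torch

def multi_query_attention(x, wq, wk, wv, num_heads):
    """Multi-Query Attention: per-head queries, one shared key/value head.

    x: (batch, seq, dim); wq: (dim, dim); wk, wv: (dim, head_dim).
    """
    b, s, d = x.shape
    head_dim = d // num_heads
    q = (x @ wq).view(b, s, num_heads, head_dim).transpose(1, 2)  # (b, h, s, hd)
    k = (x @ wk).unsqueeze(1)                                     # (b, 1, s, hd), shared
    v = (x @ wv).unsqueeze(1)                                     # broadcast over heads
    att = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
    return (att @ v).transpose(1, 2).reshape(b, s, d)

b, s, d, h = 2, 10, 768, 12
x = torch.randn(b, s, d)
out = multi_query_attention(x, torch.randn(d, d), torch.randn(d, d // h),
                            torch.randn(d, d // h), h)
print(out.shape)  # torch.Size([2, 10, 768])
```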

🚀 Installation

Prerequisites

  • Python 3.10+
  • PyTorch 2.0+
  • CUDA (optional, for GPU acceleration)

Install from Source

# Clone the repository
git clone https://github.com/kyegomez/SSM-As-VLM-Bridge.git
cd SSM-As-VLM-Bridge

# Install in development mode
pip install -e .

Install Dependencies

pip install -r requirements.txt

💻 Usage

Basic Usage

import torch
from ssm_bridge.model import EnhancedVLM, VLMConfig

# Create configuration
config = VLMConfig(
    img_size=224,
    patch_size=16,
    vision_embed_dim=768,
    vision_num_layers=12,
    vision_num_heads=12,
    vocab_size=32000,
    text_embed_dim=768,
    text_num_layers=12,
    text_num_heads=12,
    text_kv_heads=1,
    max_seq_length=2048,
    ssm_state_dim=64,
    ssm_hidden_dim=256,
    ssm_num_layers=4,
    ssm_dropout=0.1,
    cross_attn_layers=2,
    cross_attn_heads=8,
    dropout=0.1,
    layer_norm_eps=1e-5,
)

# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EnhancedVLM(config).to(device)

# Prepare input data
images = torch.randn(2, 3, 224, 224).to(device)  # Batch of images
tokens = torch.randint(0, 32000, (2, 10)).to(device)  # Text tokens
targets = torch.randint(0, 32000, (2, 10)).to(device)  # Target tokens

# Forward pass
logits, loss = model(images, tokens, targets)
print(f"Logits shape: {logits.shape}, Loss: {loss.item():.4f}")

# Text generation
generated_tokens = model.generate(
    images, 
    max_length=50, 
    temperature=0.8, 
    top_k=50,
    top_p=0.9
)
print(f"Generated tokens shape: {generated_tokens.shape}")
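The `generate` call filters the next-token distribution with temperature, top-k, and nucleus (top-p) sampling. A standalone sketch of that filtering step, assuming the conventional formulation (`sample_next_token` is illustrative, not the repo's internal function):

```python
import torch

def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.9):
    """Sample one token id from logits (vocab,) using top-k then nucleus filtering."""
    logits = logits / temperature
    # Top-k: mask everything below the k-th largest logit
    kth = torch.topk(logits, top_k).values[-1]
    logits[logits < kth] = float("-inf")
    # Top-p: drop the probability tail beyond cumulative mass top_p
    sorted_logits, idx = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cutoff = torch.cumsum(probs, dim=-1) > top_p
    cutoff[1:] = cutoff[:-1].clone()   # shift right so the first token past
    cutoff[0] = False                  # the threshold is kept
    sorted_logits[cutoff] = float("-inf")
    logits.scatter_(0, idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1), 1).item()

token = sample_next_token(torch.randn(32000))
print(0 <= token < 32000)  # True
```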

Advanced Configuration

# Custom SSM configuration for different use cases
config = VLMConfig(
    # Vision settings
    img_size=384,  # Higher resolution
    patch_size=16,
    vision_embed_dim=1024,  # Larger embeddings
    vision_num_layers=24,   # Deeper vision encoder
    
    # SSM Bridge settings
    ssm_state_dim=128,      # Larger state dimension
    ssm_hidden_dim=512,     # Larger hidden dimension
    ssm_num_layers=6,       # More SSM layers
    
    # Language settings
    text_embed_dim=1024,    # Match vision embedding
    text_num_layers=24,     # Deeper language model
    max_seq_length=4096,    # Longer sequences
)

🔬 Research Contributions

Novel SSM Bridge Design

The SSM bridge introduces several innovations:

  1. Multi-layer SSM Architecture: Stacked SSM layers with residual connections
  2. Enhanced Numerical Stability: Clamped activations and proper initialization
  3. Gating Mechanisms: Adaptive gating for better information flow
  4. Cross-modal State Management: Efficient state transitions between modalities
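Innovations 1 and 3 above can be sketched as a stack of SSM layers, each wrapping the state recurrence in an adaptive gate and a residual connection. This is a hypothetical sketch of the pattern, not the repository's `SSMBridge` class:

```python
import torch
import torch.nn as nn

class GatedSSMBlock(nn.Module):
    """One bridge layer: SSM sequence mixing + adaptive gate + residual (illustrative)."""

    def __init__(self, dim, state_dim=64):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.in_proj = nn.Linear(dim, state_dim)
        self.out_proj = nn.Linear(state_dim, dim)
        self.A = nn.Parameter(torch.randn(state_dim, state_dim) * 0.01)
        self.gate = nn.Linear(dim, dim)

    def forward(self, x):                       # x: (batch, seq, dim)
        u = self.in_proj(self.norm(x))
        h = torch.zeros(x.shape[0], u.shape[-1], device=x.device)
        ys = []
        for t in range(x.shape[1]):             # sequential state recurrence
            h = torch.tanh(h @ self.A + u[:, t])
            ys.append(h)
        y = self.out_proj(torch.stack(ys, dim=1))
        g = torch.sigmoid(self.gate(x))         # adaptive gate on information flow
        return x + g * y                        # residual connection

bridge = nn.Sequential(*[GatedSSMBlock(768) for _ in range(4)])  # 4 layers, as configured
out = bridge(torch.randn(2, 196, 768))          # 196 = 14×14 ViT patches
print(out.shape)  # torch.Size([2, 196, 768])
```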

Key Features

  • Efficient Attention: Multi-Query Attention reduces computational complexity
  • Numerical Stability: Enhanced Swish activation and proper gradient flow
  • Modular Design: Easy to modify and extend components
  • Memory Efficient: Optimized for large-scale training
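The "numerically stable Swish" idea can be illustrated by clamping the pre-activation before applying x·sigmoid(x); the clamp bound below is an assumed value, and the repo's actual activation may differ:

```python
import torch

def stable_swish(x, clamp_val=20.0):
    """Swish (x * sigmoid(x)) with input clamping to bound extreme activations.

    clamp_val=20.0 is an illustrative threshold, not the repository's setting.
    """
    x = x.clamp(-clamp_val, clamp_val)
    return x * torch.sigmoid(x)

x = torch.tensor([-1e4, -1.0, 0.0, 1.0, 1e4])
print(stable_swish(x))  # extreme inputs saturate instead of growing unboundedly
```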

📊 Model Specifications

| Component      | Parameter Count | Memory Usage | FLOPs |
|----------------|-----------------|--------------|-------|
| Vision Encoder | ~86M            | ~2 GB        | ~15G  |
| SSM Bridge     | ~2M             | ~0.5 GB      | ~5G   |
| Language Model | ~86M            | ~2 GB        | ~20G  |
| **Total**      | ~174M           | ~4.5 GB      | ~40G  |

Estimates for batch size 1, sequence length 2048, image size 224×224
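The ~86M figure for the vision encoder follows from a back-of-envelope count over the transformer blocks (assuming the standard MLP ratio of 4 and ignoring biases, norms, and embeddings):

```python
def transformer_params(dim, layers, mlp_ratio=4):
    """Rough per-stack parameter count: Q/K/V/output projections + 2-layer MLP.

    Ignores biases, layer norms, and embedding tables; mlp_ratio=4 is the
    conventional ViT/transformer choice, assumed here.
    """
    attn = 4 * dim * dim                  # four dim×dim attention projections
    mlp = 2 * dim * (mlp_ratio * dim)     # two MLP weight matrices
    return layers * (attn + mlp)

n = transformer_params(768, 12)
print(f"~{n/1e6:.0f}M")  # ~85M, consistent with the ~86M ViT-Base figure above
```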

🎯 TODO & Roadmap

Immediate Tasks (Next 2-4 weeks)

  • Train Base Model: Implement training pipeline with COCO/CC3M datasets
  • Benchmark Performance: Compare against CLIP, LLaVA, and other VLMs
  • Ablation Studies: Analyze SSM bridge effectiveness
  • Memory Optimization: Implement gradient checkpointing and mixed precision

Medium-term Goals (1-3 months)

  • All-SSM VLM: Explore replacing attention with SSMs throughout
  • Multi-modal SSM: Extend SSM to handle multiple modalities
  • Efficient Inference: Optimize for real-time applications
  • Pre-trained Models: Release checkpoints for various scales

Long-term Vision (3-6 months)

  • Large-scale Training: Train on web-scale datasets
  • Zero-shot Evaluation: Comprehensive evaluation on VQA, captioning, etc.
  • Deployment Pipeline: Easy deployment for production use
  • Community Models: Open-source ecosystem for SSM-based VLMs

🧪 Experiments & Benchmarks

Planned Evaluations

  1. Image Captioning

    • COCO Captions
    • Flickr30k
    • NoCaps
  2. Visual Question Answering

    • VQA v2.0
    • GQA
    • OK-VQA
  3. Zero-shot Classification

    • ImageNet-1k
    • ImageNet-21k
    • CIFAR-100
  4. Cross-modal Retrieval

    • MS-COCO retrieval
    • Flickr30k retrieval

Baseline Comparisons

  • CLIP: OpenAI's contrastive learning approach
  • LLaVA: Large Language and Vision Assistant
  • Flamingo: DeepMind's few-shot learning model
  • BLIP-2: Bootstrapping Language-Image Pre-training

🤝 Contributing

We welcome contributions! Please see our Contributing Guidelines for details.

Development Setup

# Install development dependencies
pip install -e ".[dev]"

# Run tests (assuming pytest is the project's test runner)
pytest

# Code formatting
black .
ruff check . --fix

📚 Citation

If you use this code in your research, please cite:

@misc{gomez2024ssmvlm,
  title={SSM-As-VLM-Bridge: State Space Models as Vision-Language Bridge Layers},
  author={Kye Gomez},
  year={2024},
  howpublished={\url{https://github.com/kyegomez/SSM-As-VLM-Bridge}},
  note={An exploration into leveraging SSMs as bridge/adapter layers for VLMs}
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • Kye Gomez - Original research and implementation
  • PyTorch Team - Excellent deep learning framework
  • Open Source Community - Inspiration and collaboration

📞 Contact

🔗 Related Work


Note: This is a research project. The model architecture and implementation are subject to ongoing development and improvement. Please check the Issues page for known limitations and planned features.
