
CortexFlow

Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching.

Decode what someone saw, heard, or thought from fMRI brain activity using a modern generative backbone — the same DiT + Rectified Flow architecture behind FLUX, Stable Diffusion 3, and Wan2.1.

Architecture

fMRI voxels
    → BrainEncoder (MLP projector → global embedding + token sequence)
    → DiffusionTransformer (AdaLN-Zero conditioning + cross-attention)
    → Rectified Flow Matching (linear ODE: noise → data)
    → Modality Decoder (VAE / Griffin-Lim / autoregressive)
    → Image / Audio / Text

Key Components

Module Description
DiffusionTransformer DiT backbone with AdaLN-Zero, QK-Norm, SwiGLU, cross-attention
RectifiedFlowMatcher Linear interpolation paths, logit-normal sampling, Euler/midpoint ODE
BrainEncoder fMRI → global embedding (AdaLN) + token sequence (cross-attention)
LatentVAE / AudioVAE Lightweight VAE for image/audio latent compression
Brain2Image Full brain → image pipeline with classifier-free guidance
Brain2Audio Brain → mel spectrogram → waveform (Griffin-Lim)
Brain2Text Brain → transformer decoder → autoregressive text (byte-level)

Installation

pip install cortexflow

With audio support:

pip install cortexflow[audio]

Quick Start

Brain → Image

import torch
from cortexflow import build_brain2img, BrainData

model = build_brain2img(n_voxels=15000, img_size=256, hidden_dim=768, depth=12)

fmri = torch.randn(1, 15000)  # your fMRI data
brain = BrainData(voxels=fmri, subject_id="sub-01")

# Reconstruct
model.eval()
result = model.reconstruct(brain, num_steps=50, cfg_scale=4.0)
image = result.output  # (1, 3, 256, 256)

Brain → Audio

from cortexflow import build_brain2audio, BrainData

model = build_brain2audio(n_voxels=15000, n_mels=80, audio_len=256)

result = model.reconstruct(BrainData(voxels=fmri), num_steps=50)
mel = result.output  # (1, 80, 256) mel spectrogram

# Convert to waveform
from cortexflow import Brain2Audio
waveform = Brain2Audio.mel_to_waveform(mel)

Brain → Text

from cortexflow import build_brain2text, BrainData

model = build_brain2text(n_voxels=15000, max_len=128, hidden_dim=512, depth=8)

result = model.reconstruct(BrainData(voxels=fmri), temperature=0.8, top_k=50)
print(result.metadata["texts"])  # ["The cat sat on the mat"]

Training

from cortexflow import build_brain2img, BrainData, Trainer, TrainingConfig

model = build_brain2img(n_voxels=15000, img_size=256)
trainer = Trainer(model, TrainingConfig(learning_rate=1e-4, batch_size=16))

# Your training loop
for images, fmri in dataloader:
    loss = trainer.train_step(
        {"stimulus": images, "fmri": fmri},
        loss_fn=lambda m, b: m.training_loss(b["stimulus"], BrainData(voxels=b["fmri"])),
    )

Advanced Features

ROI-Aware Brain Encoding

from cortexflow import ROIBrainEncoder

encoder = ROIBrainEncoder(
    roi_sizes={"V1": 2000, "V2": 1500, "FFA": 800, "PPA": 600, "A1": 1000},
    cond_dim=768,
)
brain_global, brain_tokens = encoder({"V1": v1_voxels, "V2": v2_voxels, ...})

Per-Subject Adaptation

from cortexflow import SubjectAdapter

adapter = SubjectAdapter(cond_dim=768, rank=16, n_subjects=8)
adapted = adapter(brain_global, subject_idx=torch.tensor([0]))
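The `rank` argument suggests a LoRA-style low-rank per-subject residual. A hypothetical NumPy sketch of that idea (illustrative only; the matrices `A` and `B` and their shapes are assumptions, not the library's actual internals):

```python
import numpy as np

n_subjects, cond_dim, rank = 8, 768, 16
rng = np.random.default_rng(0)
A = rng.normal(size=(n_subjects, cond_dim, rank)) * 0.01  # per-subject down-projection
B = np.zeros((n_subjects, rank, cond_dim))                # zero-init up-projection

def subject_adapt(embedding, subject_idx):
    """Add a subject-specific low-rank correction to the shared embedding."""
    s = subject_idx
    return embedding + embedding @ A[s] @ B[s]

emb = rng.normal(size=(1, cond_dim))
adapted = subject_adapt(emb, subject_idx=0)
# with B zero-initialized, adapted equals emb at the start of training
```

Zero-initializing the up-projection means each subject's adapter starts as a no-op, so adding a new subject cannot disturb the shared model until training moves `B` away from zero.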

EMA for Stable Sampling

from cortexflow import EMAModel

ema = EMAModel(model, decay=0.9999)
# During training:
ema.update(model)
# For sampling:
originals = ema.apply_to(model)
result = model.reconstruct(brain_data)
ema.restore(model, originals)
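Under the hood, an EMA helper of this kind maintains a shadow copy updated as `ema = decay * ema + (1 - decay) * param`. A plain-Python sketch over a parameter dict (the real `EMAModel` operates on module parameters):

```python
def ema_update(ema_params, model_params, decay=0.9999):
    """Blend the current parameters into the exponential moving average."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value

ema = {"w": 1.0}
for _ in range(3):                         # model parameter has jumped to 2.0
    ema_update(ema, {"w": 2.0}, decay=0.5)
# ema["w"] drifts toward 2.0: 1.5, then 1.75, then 1.875
```

A decay near 1.0 (e.g. 0.9999) makes the shadow weights change slowly, which smooths out per-step training noise and typically yields better samples than the raw weights.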

Technical Details

Why DiT + Flow Matching?

The Diffusion Transformer with rectified flow matching has become the dominant generative backbone for frontier image and video models:

  • FLUX (Black Forest Labs): DiT + flow matching
  • Stable Diffusion 3 (Stability AI): MMDiT + rectified flow
  • Wan2.1 (Alibaba): DiT + flow matching, 14B params
  • Movie Gen (Meta): DiT + flow matching, 30B params

We bring this architecture to brain decoding, replacing the dated UNet backbones used in prior work (MindEye, Brain-Diffuser).

Flow Matching Objective

Rectified flow uses linear interpolation paths:

$$x_t = (1 - t) \cdot x_0 + t \cdot x_1$$

The model learns the velocity field $v_\theta(x_t, t, c)$ that transports noise to data:

$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t, c) - (x_1 - x_0) \|^2 \right]$$
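As a concrete illustration, the objective fits in a few lines of NumPy (a stand-in predictor replaces the conditioned DiT; `t` is drawn uniformly here, whereas the library uses logit-normal sampling):

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(size=(8, 16))    # data batch
x0 = rng.normal(size=(8, 16))    # Gaussian noise batch
t = rng.uniform(size=(8, 1))     # timestep per sample (logit-normal in practice)

x_t = (1 - t) * x0 + t * x1      # linear interpolation path
target = x1 - x0                 # velocity is constant along a straight path

def v_theta(x_t, t):
    # Stand-in for the network; a perfect model predicts the target exactly.
    return target

loss = np.mean((v_theta(x_t, t) - target) ** 2)
```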

At inference, we solve the ODE from $t=0$ (noise) to $t=1$ (data) using Euler or midpoint methods.
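A minimal Euler integrator for this ODE looks like the following (NumPy sketch; the velocity field used here is the ideal constant one, for which straight-line paths make Euler exact):

```python
import numpy as np

def euler_sample(velocity_fn, x0, num_steps=50):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data)."""
    x, dt = x0.copy(), 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity_fn(x, t)
    return x

x0 = np.zeros(4)                           # "noise" start point
x1 = np.array([1.0, -2.0, 0.5, 3.0])       # target sample
sample = euler_sample(lambda x, t: x1 - x0, x0, num_steps=10)
# sample recovers x1: for a constant field, every Euler step is exact
```

The midpoint method evaluates the field once more per step at `t + dt/2`, halving the number of steps needed for curved (imperfectly rectified) trajectories.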

AdaLN-Zero Conditioning

Brain embeddings modulate the transformer via Adaptive Layer Normalization:

$$h = \gamma(c) \odot \text{LayerNorm}(x) + \beta(c)$$

with zero-initialized gating for stable training (Peebles & Xie 2022).
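A NumPy sketch of one common AdaLN-Zero formulation (the conditioning vector produces shift, scale, and a gate through a zero-initialized projection, so each block starts as the identity; the names are illustrative, not the library's internals):

```python
import numpy as np

def adaln_zero_block(x, cond, mod_weight, block_fn):
    """cond -> (shift, scale, gate); with mod_weight zero-initialized,
    all three are zero and the block reduces to the identity."""
    mod = cond @ mod_weight                          # (batch, 3*d)
    shift, scale, gate = np.split(mod, 3, axis=-1)
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    x_norm = (x - mu) / (sigma + 1e-6)               # LayerNorm, no learned affine
    h = (1 + scale) * x_norm + shift                 # modulated normalization
    return x + gate * block_fn(h)                    # zero-gated residual update

rng = np.random.default_rng(0)
batch, d = 2, 8
x = rng.normal(size=(batch, d))
cond = rng.normal(size=(batch, d))
W = np.zeros((d, 3 * d))                             # zero-initialized projection
out = adaln_zero_block(x, cond, W, block_fn=lambda h: 2.0 * h)
# out equals x exactly: the zero gate silences the block at initialization
```

Starting every residual branch at zero lets deep stacks of conditioned blocks train stably from scratch, since the network begins as a plain identity map.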

References

  • Peebles & Xie (2022). "Scalable Diffusion Models with Transformers." arXiv:2212.09748
  • Esser et al. (2024). "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis." arXiv:2403.03206
  • Lipman et al. (2024). "Flow Matching Guide and Code." arXiv:2412.06264
  • Scotti et al. (2023). "Reconstructing the Mind's Eye: fMRI-to-Image with Contrastive Learning and Diffusion Priors."
  • Ozcelik & VanRullen (2023). "Brain-Diffuser: Natural Scene Reconstruction from fMRI Using a Latent Diffusion Model."

License

Apache-2.0
