Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching. Decode what someone saw, heard, or thought from fMRI.

CortexFlow

Decode what someone saw, heard, or thought from fMRI brain activity using a modern generative backbone: the same DiT + Rectified Flow architecture behind FLUX, Stable Diffusion 3, and Wan2.1.

Architecture

fMRI voxels
    → BrainEncoder (MLP projector → global embedding + token sequence)
    → DiffusionTransformer (AdaLN-Zero conditioning + cross-attention)
    → Rectified Flow Matching (linear ODE: noise → data)
    → Modality Decoder (VAE / Griffin-Lim / autoregressive)
    → Image / Audio / Text

Key Components

| Module | Description |
|---|---|
| DiffusionTransformer | DiT backbone with AdaLN-Zero, QK-Norm, SwiGLU, cross-attention |
| RectifiedFlowMatcher | Linear interpolation paths, logit-normal sampling, Euler/midpoint ODE |
| BrainEncoder | fMRI → global embedding (AdaLN) + token sequence (cross-attention) |
| LatentVAE / AudioVAE | Lightweight VAE for image/audio latent compression |
| Brain2Image | Full brain → image pipeline with classifier-free guidance |
| Brain2Audio | Brain → mel spectrogram → waveform (Griffin-Lim) |
| Brain2Text | Brain → transformer decoder → autoregressive text (byte-level) |
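
To make the BrainEncoder row concrete, here is a minimal sketch of an MLP projector that turns a voxel vector into a global embedding and a token sequence. The class name `TinyBrainEncoder`, the layer layout, and the `n_tokens` parameter are assumptions for illustration only; CortexFlow's actual encoder may be structured differently.

```python
import torch
from torch import nn


class TinyBrainEncoder(nn.Module):
    """Hypothetical sketch: voxels -> (global embedding, token sequence)."""

    def __init__(self, n_voxels: int, cond_dim: int = 768, n_tokens: int = 16):
        super().__init__()
        self.n_tokens = n_tokens
        self.cond_dim = cond_dim
        # Shared MLP projector over the flattened voxel vector.
        self.projector = nn.Sequential(
            nn.Linear(n_voxels, cond_dim),
            nn.GELU(),
            nn.LayerNorm(cond_dim),
        )
        # Head for the pooled embedding (consumed by AdaLN).
        self.to_global = nn.Linear(cond_dim, cond_dim)
        # Head that expands into tokens (consumed by cross-attention).
        self.to_tokens = nn.Linear(cond_dim, n_tokens * cond_dim)

    def forward(self, voxels: torch.Tensor):
        h = self.projector(voxels)                     # (B, cond_dim)
        brain_global = self.to_global(h)               # (B, cond_dim)
        brain_tokens = self.to_tokens(h).view(-1, self.n_tokens, self.cond_dim)
        return brain_global, brain_tokens


encoder = TinyBrainEncoder(n_voxels=15000, cond_dim=64, n_tokens=4)
brain_global, brain_tokens = encoder(torch.randn(2, 15000))
```

The two outputs correspond to the two conditioning routes in the architecture diagram: one vector per sample for AdaLN, and a short sequence for cross-attention.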

Installation

pip install cortexflowx

With audio support:

pip install "cortexflowx[audio]"

Quick Start

Brain → Image

import torch
from cortexflow import build_brain2img, BrainData

model = build_brain2img(n_voxels=15000, img_size=256, hidden_dim=768, depth=12)

fmri = torch.randn(1, 15000)  # your fMRI data
brain = BrainData(voxels=fmri, subject_id="sub-01")

# Reconstruct
model.eval()
result = model.reconstruct(brain, num_steps=50, cfg_scale=4.0)
image = result.output  # (1, 3, 256, 256)
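
The `cfg_scale` argument refers to classifier-free guidance. How CortexFlow applies it internally is not shown here, so the following is only a sketch of the standard formulation applied to a velocity prediction. The function `guided_velocity`, the `null_cond` argument, and the toy velocity field are all hypothetical.

```python
import torch


def guided_velocity(velocity_fn, x_t, t, cond, null_cond, cfg_scale=4.0):
    """Blend conditional and unconditional velocity predictions."""
    v_cond = velocity_fn(x_t, t, cond)
    v_uncond = velocity_fn(x_t, t, null_cond)
    # cfg_scale = 1 recovers the conditional prediction; larger values
    # push the update further along the brain-conditioned direction.
    return v_uncond + cfg_scale * (v_cond - v_uncond)


# Toy velocity field so the example runs without a trained model.
toy_fn = lambda x, t, c: c - x
x_t = torch.zeros(1, 4)
cond = torch.ones(1, 4)
null_cond = torch.zeros(1, 4)
t = torch.tensor([0.5])

v_guided = guided_velocity(toy_fn, x_t, t, cond, null_cond, cfg_scale=4.0)
v_plain = guided_velocity(toy_fn, x_t, t, cond, null_cond, cfg_scale=1.0)
```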

Brain → Audio

import torch
from cortexflow import build_brain2audio, Brain2Audio, BrainData

model = build_brain2audio(n_voxels=15000, n_mels=80, audio_len=256)

fmri = torch.randn(1, 15000)  # your fMRI data
model.eval()
result = model.reconstruct(BrainData(voxels=fmri), num_steps=50)
mel = result.output  # (1, 80, 256) mel spectrogram

# Convert to waveform
waveform = Brain2Audio.mel_to_waveform(mel)
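
For readers unfamiliar with Griffin-Lim, the sketch below shows the core phase-recovery loop in plain PyTorch. It works on a linear-frequency magnitude spectrogram and skips the mel-to-linear inversion step, so it is not a stand-in for `Brain2Audio.mel_to_waveform`, whose internals are not described here. The function name and the STFT settings (`n_fft=512`, `hop_length=128`) are assumptions.

```python
import math

import torch


def griffin_lim(magnitude, n_fft=512, hop_length=128, n_iter=32):
    """Estimate a waveform from a (n_fft // 2 + 1, frames) magnitude spectrogram."""
    window = torch.hann_window(n_fft)
    # Start from random phase on the unit circle.
    phase = torch.polar(
        torch.ones_like(magnitude), 2 * math.pi * torch.rand_like(magnitude)
    )
    for _ in range(n_iter):
        # Go to the time domain, come back, and keep only the new phase.
        wave = torch.istft(magnitude * phase, n_fft, hop_length=hop_length, window=window)
        rebuilt = torch.stft(
            wave, n_fft, hop_length=hop_length, window=window, return_complex=True
        )
        phase = rebuilt / rebuilt.abs().clamp_min(1e-8)
    return torch.istft(magnitude * phase, n_fft, hop_length=hop_length, window=window)


# Round trip on a synthetic 440 Hz tone sampled at 16 kHz.
tone = torch.sin(2 * math.pi * 440.0 * torch.arange(4096) / 16000.0)
magnitude = torch.stft(
    tone, 512, hop_length=128, window=torch.hann_window(512), return_complex=True
).abs()
recovered = griffin_lim(magnitude, n_iter=8)
```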

Brain → Text

import torch
from cortexflow import build_brain2text, BrainData

model = build_brain2text(n_voxels=15000, max_len=128, hidden_dim=512, depth=8)

fmri = torch.randn(1, 15000)  # your fMRI data
model.eval()
result = model.reconstruct(BrainData(voxels=fmri), temperature=0.8, top_k=50)
print(result.metadata["texts"])  # ["The cat sat on the mat"]
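
The `temperature` and `top_k` arguments control how each next byte is sampled. Below is a minimal sketch of that sampling step, assuming a vocabulary of exactly 256 byte values; the function name `sample_next_byte` is hypothetical, and CortexFlow's decoder may add special tokens or differ in other details.

```python
import torch


def sample_next_byte(logits, temperature=0.8, top_k=50):
    """Sample one token id per row from (B, vocab) logits."""
    logits = logits / temperature
    # Keep the top_k largest logits per row and mask the rest to -inf.
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


logits = torch.randn(2, 256)  # assumed byte-level vocabulary of 256 values
sampled = sample_next_byte(logits, temperature=0.8, top_k=50)
greedy = sample_next_byte(logits, temperature=1.0, top_k=1)  # top_k=1 is argmax

# Byte ids map back to text by decoding them as UTF-8.
text = bytes([72, 105]).decode("utf-8", errors="replace")
```

Lower temperatures sharpen the distribution, and smaller `top_k` values restrict sampling to the most likely bytes.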

Training

from cortexflow import build_brain2img, BrainData, Trainer, TrainingConfig

model = build_brain2img(n_voxels=15000, img_size=256)
trainer = Trainer(model, TrainingConfig(learning_rate=1e-4, batch_size=16))

# `dataloader` is your own iterable of (images, fmri) batches
for images, fmri in dataloader:
    loss = trainer.train_step(
        {"stimulus": images, "fmri": fmri},
        # loss_fn receives the model and the batch dict passed above
        loss_fn=lambda m, b: m.training_loss(b["stimulus"], BrainData(voxels=b["fmri"])),
    )

Advanced Features

ROI-Aware Brain Encoding

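import torch
from cortexflow import ROIBrainEncoder

roi_sizes = {"V1": 2000, "V2": 1500, "FFA": 800, "PPA": 600, "A1": 1000}
encoder = ROIBrainEncoder(roi_sizes=roi_sizes, cond_dim=768)

# One voxel tensor per ROI. Random placeholders here; the shape
# (batch, n_voxels_in_roi) is assumed from the Quick Start examples.
roi_voxels = {name: torch.randn(1, size) for name, size in roi_sizes.items()}
brain_global, brain_tokens = encoder(roi_voxels)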

Per-Subject Adaptation

import torch
from cortexflow import SubjectAdapter

adapter = SubjectAdapter(cond_dim=768, rank=16, n_subjects=8)
# brain_global: global embedding from the brain encoder (see above)
adapted = adapter(brain_global, subject_idx=torch.tensor([0]))
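
The `rank` argument suggests a low-rank, LoRA-style adapter, although SubjectAdapter's internals are not described here. Under that assumption, a per-subject adapter could look like the sketch below. The class `LowRankSubjectAdapter` and its parameter layout are hypothetical.

```python
import torch
from torch import nn


class LowRankSubjectAdapter(nn.Module):
    """Hypothetical sketch: per-subject low-rank residual on the global embedding."""

    def __init__(self, cond_dim: int, rank: int, n_subjects: int):
        super().__init__()
        # One (down, up) factor pair per subject.
        self.down = nn.Parameter(torch.randn(n_subjects, cond_dim, rank) * 0.02)
        # Zero-initialized so the adapter starts as the identity.
        self.up = nn.Parameter(torch.zeros(n_subjects, rank, cond_dim))

    def forward(self, brain_global: torch.Tensor, subject_idx: torch.Tensor):
        down = self.down[subject_idx]                  # (B, cond_dim, rank)
        up = self.up[subject_idx]                      # (B, rank, cond_dim)
        delta = torch.bmm(torch.bmm(brain_global.unsqueeze(1), down), up)
        return brain_global + delta.squeeze(1)         # (B, cond_dim)


sketch_adapter = LowRankSubjectAdapter(cond_dim=32, rank=4, n_subjects=8)
embedding = torch.randn(2, 32)
adapted_sketch = sketch_adapter(embedding, subject_idx=torch.tensor([0, 3]))
```

Each subject adds only `2 * cond_dim * rank` parameters, which keeps adaptation cheap when data per subject is limited.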

EMA for Stable Sampling

from cortexflow import EMAModel

ema = EMAModel(model, decay=0.9999)
# During training:
ema.update(model)
# For sampling:
originals = ema.apply_to(model)
result = model.reconstruct(brain_data)
ema.restore(model, originals)
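
For reference, the exponential moving average behind this kind of helper is a one-line update per parameter. The sketch below uses plain PyTorch and a hypothetical `ema_update` function; it illustrates the rule and is not CortexFlow's `EMAModel` implementation.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float = 0.9999):
    """ema <- decay * ema + (1 - decay) * current, for every parameter."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)


net = nn.Linear(4, 4)
ema_net = copy.deepcopy(net).requires_grad_(False)

with torch.no_grad():
    net.weight.add_(1.0)  # stand-in for an optimizer step

ema_update(ema_net, net, decay=0.9)
```

With `decay=0.9`, the averaged weights move one tenth of the way toward the updated weights on each call.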

Technical Details

Why DiT + Flow Matching?

Diffusion Transformers trained with rectified flow matching underpin several recent large-scale image and video generators:

  • FLUX (Black Forest Labs): DiT + flow matching
  • Stable Diffusion 3 (Stability AI): MMDiT + rectified flow
  • Wan2.1 (Alibaba): DiT + flow matching, 14B params
  • Movie Gen (Meta): DiT + flow matching, 30B params

We bring this architecture to brain decoding, replacing the dated UNet backbones used in prior work (MindEye, Brain-Diffuser).

Flow Matching Objective

Rectified flow uses linear interpolation paths between a noise sample $x_0$ and a data sample $x_1$, with $t \in [0, 1]$:

$$x_t = (1 - t) \cdot x_0 + t \cdot x_1$$

The model learns a velocity field $v_\theta(x_t, t, c)$, conditioned on the brain embedding $c$, that transports noise to data:

$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t, c) - (x_1 - x_0) \|^2 \right]$$
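
A minimal sketch of this objective in PyTorch follows, using logit-normal timestep sampling with zero mean and unit variance. The function name `flow_matching_loss` and the `velocity_fn(x_t, t, cond)` calling convention are assumptions, not CortexFlow's API.

```python
import torch


def flow_matching_loss(velocity_fn, x1, cond):
    """Rectified-flow loss for a batch of data samples x1 with shape (B, ...)."""
    x0 = torch.randn_like(x1)                          # noise endpoint
    # Logit-normal timestep sampling: t = sigmoid(n), n ~ N(0, 1).
    t = torch.sigmoid(torch.randn(x1.shape[0]))
    t_b = t.view(-1, *([1] * (x1.dim() - 1)))          # broadcast over data dims
    x_t = (1 - t_b) * x0 + t_b * x1                    # point on the linear path
    target = x1 - x0                                   # constant velocity of that path
    return ((velocity_fn(x_t, t, cond) - target) ** 2).mean()


# Toy check with a velocity field that always predicts zero.
x1 = torch.zeros(8, 4)
loss = flow_matching_loss(lambda x, t, c: torch.zeros_like(x), x1, cond=None)
```

Because the path is a straight line, the regression target $x_1 - x_0$ does not depend on $t$.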

At inference, we solve the ODE from $t=0$ (noise) to $t=1$ (data) using Euler or midpoint methods.
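
The two solvers can be sketched in a few lines. This is an illustration with a hypothetical `solve_ode` function and a uniform step size; the library's sampler may use a different schedule.

```python
import torch


def solve_ode(velocity_fn, x0, cond, num_steps=50, method="euler"):
    """Integrate dx/dt = v(x, t, c) from t=0 (noise) to t=1 (data)."""
    x = x0
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((x.shape[0],), i * dt)
        v = velocity_fn(x, t, cond)
        if method == "midpoint":
            # Re-evaluate the velocity at the half step (second-order accurate).
            v = velocity_fn(x + 0.5 * dt * v, t + 0.5 * dt, cond)
        x = x + dt * v
    return x


# Toy field with constant velocity c, so the exact solution is x0 + c.
constant_field = lambda x, t, c: c.expand_as(x)
start = torch.zeros(2, 3)
direction = torch.ones(1, 3)
x_euler = solve_ode(constant_field, start, direction, num_steps=50, method="euler")
x_mid = solve_ode(constant_field, start, direction, num_steps=50, method="midpoint")
```

Midpoint costs two model evaluations per step but tolerates fewer steps when the learned velocity field is curved.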

AdaLN-Zero Conditioning

Brain embeddings modulate the transformer via Adaptive Layer Normalization:

$$h = \gamma(c) \odot \text{LayerNorm}(x) + \beta(c)$$

A gate $\alpha(c)$ on each residual branch is initialized to zero, so every block starts as the identity function, which stabilizes early training (Peebles & Xie 2022).
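
As a rough sketch of the idea, the block below applies AdaLN-Zero modulation around an MLP branch. It follows the DiT convention of predicting a scale offset, so that $\gamma(c) = 1 + \text{scale}(c)$. The class name `AdaLNZeroMLPBlock` and the MLP-only layout are simplifying assumptions; a full DiT block also modulates its attention branch.

```python
import torch
from torch import nn


class AdaLNZeroMLPBlock(nn.Module):
    """Hypothetical sketch of AdaLN-Zero conditioning around one MLP branch."""

    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        # The learned affine is replaced by condition-dependent modulation.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        # Predict shift (beta), scale offset, and gate from the condition.
        self.modulation = nn.Linear(cond_dim, 3 * dim)
        nn.init.zeros_(self.modulation.weight)   # the "Zero" in AdaLN-Zero
        nn.init.zeros_(self.modulation.bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor):
        beta, scale, gate = self.modulation(c).unsqueeze(1).chunk(3, dim=-1)
        gamma = 1 + scale
        h = gamma * self.norm(x) + beta          # h = gamma(c) * LayerNorm(x) + beta(c)
        return x + gate * self.mlp(h)            # gated residual branch


block = AdaLNZeroMLPBlock(dim=32, cond_dim=16)
tokens = torch.randn(2, 5, 32)                   # (batch, sequence, dim)
brain_embedding = torch.randn(2, 16)             # (batch, cond_dim)
out = block(tokens, brain_embedding)
```

At initialization the gate is zero, so `out` equals `tokens` and the block contributes nothing until training moves the modulation weights.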

References

  • Peebles & Xie (2022). "Scalable Diffusion Models with Transformers." arXiv:2212.09748
  • Esser et al. (2024). "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis." arXiv:2403.03206
  • Lipman et al. (2024). "Flow Matching Guide and Code." arXiv:2412.06264
  • Scotti et al. (2023). "Reconstructing the Mind's Eye: fMRI-to-Image with Contrastive Learning and Diffusion Priors." arXiv:2305.18274
  • Ozcelik & VanRullen (2023). "Brain-Diffuser: Natural Scene Reconstruction from fMRI Signals Using Generative Latent Diffusion." arXiv:2303.05334

License

Apache-2.0
