
CortexFlow

Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching.

Decode what someone saw, heard, or thought from fMRI brain activity using a modern generative backbone — the same DiT + Rectified Flow architecture behind FLUX, Stable Diffusion 3, and Wan2.1.

Architecture

fMRI voxels
    → BrainEncoder (MLP projector → global embedding + token sequence)
    → DiffusionTransformer (AdaLN-Zero conditioning + cross-attention)
    → Rectified Flow Matching (linear ODE: noise → data)
    → Modality Decoder (VAE / Griffin-Lim / autoregressive)
    → Image / Audio / Text

Key Components

Module                Description
DiffusionTransformer  DiT backbone with AdaLN-Zero, QK-Norm, SwiGLU, cross-attention
RectifiedFlowMatcher  Linear interpolation paths, logit-normal sampling, Euler/midpoint ODE
BrainEncoder          fMRI → global embedding (AdaLN) + token sequence (cross-attention)
LatentVAE / AudioVAE  Lightweight VAE for image/audio latent compression
Brain2Image           Full brain → image pipeline with classifier-free guidance
Brain2Audio           Brain → mel spectrogram → waveform (Griffin-Lim)
Brain2Text            Brain → transformer decoder → autoregressive text (byte-level)
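The logit-normal timestep sampling listed for RectifiedFlowMatcher can be sketched in a few lines: draw u from a Gaussian and squash it through a sigmoid, so training concentrates on mid-range timesteps. This is an illustrative sketch, not the package's exact implementation; the function name and parameters are assumptions.

```python
import torch

def sample_timesteps(batch_size: int, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
    """Logit-normal timestep sampling: t = sigmoid(u), u ~ N(mean, std).

    Concentrates training on mid-range t, where the velocity target
    is hardest to predict (Esser et al. 2024).
    """
    u = torch.randn(batch_size) * std + mean
    return torch.sigmoid(u)

t = sample_timesteps(4)  # four values in the open interval (0, 1)
```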

Installation

pip install cortexflowx

With audio support:

pip install cortexflowx[audio]

Quick Start

Brain → Image

import torch
from cortexflow import build_brain2img, BrainData

model = build_brain2img(n_voxels=15000, img_size=256, hidden_dim=768, depth=12)

fmri = torch.randn(1, 15000)  # your fMRI data
brain = BrainData(voxels=fmri, subject_id="sub-01")

# Reconstruct
model.eval()
result = model.reconstruct(brain, num_steps=50, cfg_scale=4.0)
image = result.output  # (1, 3, 256, 256)

Brain → Audio

from cortexflow import build_brain2audio, BrainData

model = build_brain2audio(n_voxels=15000, n_mels=80, audio_len=256)

result = model.reconstruct(BrainData(voxels=fmri), num_steps=50)
mel = result.output  # (1, 80, 256) mel spectrogram

# Convert to waveform
from cortexflow import Brain2Audio
waveform = Brain2Audio.mel_to_waveform(mel)

Brain → Text

from cortexflow import build_brain2text, BrainData

model = build_brain2text(n_voxels=15000, max_len=128, hidden_dim=512, depth=8)

result = model.reconstruct(BrainData(voxels=fmri), temperature=0.8, top_k=50)
print(result.metadata["texts"])  # ["The cat sat on the mat"]

Training

from cortexflow import build_brain2img, BrainData, Trainer, TrainingConfig

model = build_brain2img(n_voxels=15000, img_size=256)
trainer = Trainer(model, TrainingConfig(learning_rate=1e-4, batch_size=16))

# Your training loop
for images, fmri in dataloader:
    brain = BrainData(voxels=fmri)
    loss = trainer.train_step(
        {"stimulus": images, "fmri": fmri},
        loss_fn=lambda m, b: m.training_loss(b["stimulus"], BrainData(voxels=b["fmri"]))
    )

Advanced Features

ROI-Aware Brain Encoding

from cortexflow import ROIBrainEncoder

encoder = ROIBrainEncoder(
    roi_sizes={"V1": 2000, "V2": 1500, "FFA": 800, "PPA": 600, "A1": 1000},
    cond_dim=768,
)
brain_global, brain_tokens = encoder({"V1": v1_voxels, "V2": v2_voxels})  # ... plus remaining ROIs

Per-Subject Adaptation

from cortexflow import SubjectAdapter

adapter = SubjectAdapter(cond_dim=768, rank=16, n_subjects=8)
adapted = adapter(brain_global, subject_idx=torch.tensor([0]))

EMA for Stable Sampling

from cortexflow import EMAModel

ema = EMAModel(model, decay=0.9999)
# During training:
ema.update(model)
# For sampling:
originals = ema.apply_to(model)
result = model.reconstruct(brain_data)
ema.restore(model, originals)

Technical Details

Why DiT + Flow Matching?

The Diffusion Transformer with rectified flow matching is the current state-of-the-art generative backbone:

  • FLUX (Black Forest Labs): DiT + flow matching
  • Stable Diffusion 3 (Stability AI): MMDiT + rectified flow
  • Wan2.1 (Alibaba): DiT + flow matching, 14B params
  • Movie Gen (Meta): DiT + flow matching, 30B params

We bring this architecture to brain decoding, replacing the older UNet backbones used in prior brain-decoding work (MindEye, Brain-Diffuser).

Flow Matching Objective

Rectified flow uses linear interpolation paths:

$$x_t = (1 - t) \cdot x_0 + t \cdot x_1$$

The model learns the velocity field $v_\theta(x_t, t, c)$ that transports noise to data:

$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t, c) - (x_1 - x_0) \|^2 \right]$$

At inference, we solve the ODE from $t=0$ (noise) to $t=1$ (data) using Euler or midpoint methods.
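In PyTorch terms, the objective and the Euler sampler amount to the following. This is a hedged sketch, assuming a model with signature `model(x, t, cond)` that returns the velocity and accepts `cond=None` for the unconditional pass; CortexFlow's internal API may differ.

```python
import torch

def flow_matching_loss(model, x1, cond):
    """Rectified-flow loss: regress the constant velocity (x1 - x0)."""
    x0 = torch.randn_like(x1)                    # noise endpoint (t = 0)
    t = torch.sigmoid(torch.randn(x1.shape[0]))  # logit-normal timesteps
    t_ = t.view(-1, *([1] * (x1.dim() - 1)))     # broadcast t over data dims
    xt = (1 - t_) * x0 + t_ * x1                 # linear interpolation path
    v_pred = model(xt, t, cond)
    return ((v_pred - (x1 - x0)) ** 2).mean()

@torch.no_grad()
def euler_sample(model, shape, cond, num_steps=50, cfg_scale=4.0):
    """Integrate dx/dt = v from t=0 (noise) to t=1 (data) with Euler steps,
    mixing conditional and unconditional velocities for classifier-free guidance."""
    x = torch.randn(shape)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((shape[0],), i * dt)
        v_cond = model(x, t, cond)
        v_uncond = model(x, t, None)
        v = v_uncond + cfg_scale * (v_cond - v_uncond)  # CFG on the velocity
        x = x + dt * v
    return x
```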

AdaLN-Zero Conditioning

Brain embeddings modulate the transformer via Adaptive Layer Normalization:

$$h = \gamma(c) \odot \text{LayerNorm}(x) + \beta(c)$$

with zero-initialized gating for stable training (Peebles & Xie 2022).
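A minimal AdaLN-Zero block looks roughly like this. It is an illustrative sketch; the module and parameter names are assumptions, not CortexFlow's exact API.

```python
import torch
import torch.nn as nn

class AdaLNZero(nn.Module):
    """Adaptive LayerNorm with zero-initialized modulation (Peebles & Xie 2022).

    The conditioning vector c produces scale, shift, and gate. Zero-init
    makes the block output zeros at initialization, so each residual
    branch starts as the identity, which stabilizes training.
    """
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.mod = nn.Linear(cond_dim, 3 * dim)
        nn.init.zeros_(self.mod.weight)  # gamma = beta = gate = 0 at init
        nn.init.zeros_(self.mod.bias)

    def forward(self, x, c):
        # x: (B, T, dim) token sequence, c: (B, cond_dim) brain embedding
        gamma, beta, gate = [m.unsqueeze(1) for m in self.mod(c).chunk(3, dim=-1)]
        h = (1 + gamma) * self.norm(x) + beta  # modulated normalization
        return gate * h                        # gated residual branch

block = AdaLNZero(dim=768, cond_dim=768)
```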

References

  • Peebles & Xie (2022). "Scalable Diffusion Models with Transformers." arXiv:2212.09748
  • Esser et al. (2024). "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis." arXiv:2403.03206
  • Lipman et al. (2024). "Flow Matching Guide and Code." arXiv:2412.06264
  • Scotti et al. (2023). "Reconstructing the Mind's Eye: fMRI-to-Image with Contrastive Learning and Diffusion Priors."
  • Ozcelik & VanRullen (2023). "Brain-Diffuser: Natural Scene Reconstruction from fMRI Using a Latent Diffusion Model."

License

Apache-2.0
