
Visual Brain AI - Multi-task brain MRI analysis library


Vbai - Visual Brain AI

Python 3.8+ | PyTorch | License: MIT

A PyTorch-based deep learning library for multi-task brain MRI analysis. Train models for dementia classification, brain tumor detection, and 3D volumetric NIfTI analysis with just a few lines of code.

Features

  • 2D & 3D Support: RGB image analysis and NIfTI (.nii/.nii.gz) volumetric processing
  • Flexible Task Selection: Train for dementia only, tumor only, or both simultaneously
  • Easy to Use: Keras-like API for quick training
  • Task-Specific Attention: Separate attention mechanisms for each task
  • MRI Artifact Simulation: Bias field, ghosting, spike noise, Rician noise
  • Elastic Deformation: Anatomical variation simulation for 2D and 3D
  • MixUp / CutMix: Modern batch-level augmentation techniques
  • AutoAugment: Automatic MRI-specific augmentation policy
  • HuggingFace Hub: Share and download models via HuggingFace
  • ONNX Export: Production deployment for both 2D and 3D models
  • Model Zoo: Registry of pretrained model configurations
  • Visualization: Built-in attention heatmap visualization
  • Configurable: YAML/JSON configuration support

Supported Classifications

2D - Dementia (6 classes):

  • AD Alzheimer's Disease
  • AD Mild Demented
  • AD Moderate Demented
  • AD Very Mild Demented
  • CN Non-Demented (Cognitively Normal)
  • PD Parkinson's Disease

2D - Brain Tumor (4 classes):

  • Glioma
  • Meningioma
  • No Tumor
  • Pituitary

3D - Alzheimer's (NIfTI):

  • CN (Cognitively Normal)
  • MCI (Mild Cognitive Impairment)
  • AD (Alzheimer's Disease)

Installation

# Basic installation
pip install vbai

# With NIfTI (3D) support (quote extras so shells such as zsh don't expand the brackets)
pip install "vbai[nifti]"

# With HuggingFace Hub integration
pip install "vbai[hub]"

# With ONNX export
pip install "vbai[onnx]"

# With all optional dependencies
pip install "vbai[full]"

# Development installation
git clone https://github.com/Neurazum-AI-Department/vbai.git
cd vbai
pip install -e ".[dev]"

Quick Start

2D Training (Dementia & Tumor)

import vbai

# Create model for both tasks (default)
model = vbai.MultiTaskBrainModel(variant='q')  # 'q' for quality, 'f' for fast

# Prepare dataset
dataset = vbai.UnifiedMRIDataset(
    dementia_path='./data/dementia/train',
    tumor_path='./data/tumor/train',
    is_training=True
)

# Create trainer and train
trainer = vbai.Trainer(model=model, lr=0.0005, device='cuda')
history = trainer.fit(train_data=dataset, epochs=10, batch_size=32)
trainer.save('brain_model.pt')

3D NIfTI Training (Volumetric)

import vbai

# Create 3D model
model = vbai.MultiTask3DBrainModel(
    variant='q',
    tasks={'alzheimer': 3},
    input_shape=(96, 96, 96)
)

# Create a NIfTI dataset (standalone; create_3d_dataloaders below builds its loaders from `root`)
dataset = vbai.NIfTIDataset(
    root='./data/alzheimer_3d',  # Subfolders: CN/, MCI/, AD/
    target_shape=(96, 96, 96),
    is_training=True
)

# Dataloaders
train_loader, val_loader = vbai.create_3d_dataloaders(
    root='./data/alzheimer_3d',
    batch_size=4, val_split=0.2
)

# Train
trainer = vbai.Trainer3D(model=model, lr=1e-4, device='cuda')
history = trainer.fit(train_loader, val_loader, epochs=25)
trainer.save('alzheimer_3d.pt')

Single-Task Training

import vbai

# Dementia only
model = vbai.MultiTaskBrainModel(variant='q', tasks=['dementia'])

# Tumor only
model = vbai.MultiTaskBrainModel(variant='q', tasks=['tumor'])

Making Predictions

import vbai

# 2D prediction
model = vbai.load('brain_model.pt', device='cuda')
result = model.predict('brain_scan.jpg')
print(f"Dementia: {result.dementia_class} ({result.dementia_confidence:.1%})")
print(f"Tumor: {result.tumor_class} ({result.tumor_confidence:.1%})")

# 3D NIfTI prediction
model_3d = vbai.load_3d('alzheimer_3d.pt', device='cuda')
result = model_3d.predict('scan.nii.gz', task='alzheimer',
                          class_names=['CN', 'MCI', 'AD'])
print(f"{result.predicted_class}: {result.confidence:.1%}")

HuggingFace Hub

import vbai

# List available models
models = vbai.list_models()        # All models
models_3d = vbai.list_models('3d') # Only 3D models

# Download and load from Hub
model = vbai.from_hub('Neurazum/vbai-3d-q', device='cuda')

# Push your trained model to Hub
model.push_to_hub('username/my-brain-model')

# Or use the functional API
vbai.push_to_hub(model, 'username/my-brain-model', private=True)

ONNX Export

import vbai

# Export 2D model
model_2d = vbai.MultiTaskBrainModel(variant='q')
model_2d.export_onnx('model_2d.onnx')

# Export 3D model
model_3d = vbai.MultiTask3DBrainModel(variant='q', tasks={'alzheimer': 3})
model_3d.export_onnx('model_3d.onnx')

# Or use the functional API
vbai.export_onnx(model_2d, 'model_2d.onnx')

# PyTorch-free inference with ONNX
onnx_model = vbai.ONNXModel('model_3d.onnx')
output = onnx_model.predict_nifti('brain_scan.nii.gz')
probs = onnx_model.softmax(output)
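`ONNXModel.softmax` turns the raw model output into class probabilities. As a sketch of what that step computes, assuming a standard softmax over the class axis and an output of shape (batch, num_classes) (neither is confirmed above), here is a numerically stable NumPy version applied to a hypothetical logit vector for the three Alzheimer's classes:

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax: subtract the max before exponentiating."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


# Hypothetical raw output for one scan; class order taken from the 3D prediction example.
class_names = ['CN', 'MCI', 'AD']
logits = np.array([[1.2, 0.3, 2.5]], dtype=np.float32)  # shape (batch, classes)

probs = softmax(logits)
pred = int(np.argmax(probs, axis=-1)[0])
print(f"{class_names[pred]}: {probs[0, pred]:.1%}")
```

Subtracting the row maximum leaves the result unchanged mathematically but prevents `np.exp` from overflowing on large logits.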

Data Augmentation

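import vbai
import numpy as np
import torch

# Placeholder inputs for illustration (the shapes/layouts vbai expects are assumptions)
volume = np.random.rand(96, 96, 96).astype(np.float32)     # 3D volume
image_2d = np.random.rand(224, 224, 3).astype(np.float32)  # RGB image, H x W x C (assumed)
images = torch.randn(8, 3, 224, 224)                       # batch of 2D images
labels = torch.randint(0, 6, (8,))                         # integer class labels
criterion = torch.nn.CrossEntropyLoss()

# ── MRI Artifact Simulation ──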

# Simulate individual artifacts
volume = vbai.simulate_bias_field(volume, intensity=0.3)
volume = vbai.simulate_ghosting(volume, num_ghosts=3, intensity=0.15)
volume = vbai.simulate_rician_noise(volume, std=0.03)
volume = vbai.simulate_spike_noise(volume, num_spikes=1, intensity=0.5)

# Or apply random artifacts in one call
volume = vbai.simulate_mri_artifacts(volume, p=0.5)

# ── Elastic Deformation ──
deformed_2d = vbai.elastic_deformation_2d(image_2d, alpha=50, sigma=5)
deformed_3d = vbai.elastic_deformation_3d(volume, alpha=30, sigma=4)

# ── MixUp / CutMix (batch-level, works with both 2D and 3D) ──
mixed, labels_a, labels_b, lam = vbai.mixup(images, labels, alpha=0.2)
outputs = model(mixed)  # `model` is assumed to be a single-output classifier returning logits
loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)

mixed, labels_a, labels_b, lam = vbai.cutmix(images, labels, alpha=1.0)

# ── AutoAugment (automatic MRI-specific policy) ──
augmenter = vbai.MRIAutoAugment(mode='3d', num_policies=10)
augmented_volume = augmenter(volume)

augmenter_2d = vbai.MRIAutoAugment(mode='2d', num_policies=10)
augmented_image = augmenter_2d(image_2d)
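The internals of `vbai.mixup` and `vbai.cutmix` aren't documented here. For intuition, the following NumPy sketch implements the standard MixUp and CutMix recipes (a Beta-sampled λ, a shuffled partner batch, and for CutMix a box covering about 1 − λ of the image). It assumes 2D batches in (N, C, H, W) layout; the function names are hypothetical and this is not vbai's code.

```python
import numpy as np

rng = np.random.default_rng(0)


def mixup_sketch(images, labels, alpha=0.2):
    """MixUp: blend each sample with a randomly chosen partner from the same batch."""
    lam = float(rng.beta(alpha, alpha))
    perm = rng.permutation(len(images))
    mixed = lam * images + (1.0 - lam) * images[perm]
    return mixed, labels, labels[perm], lam


def cutmix_sketch(images, labels, alpha=1.0):
    """CutMix: paste a random H x W box from a partner image (N x C x H x W batches)."""
    lam = float(rng.beta(alpha, alpha))
    perm = rng.permutation(len(images))
    h, w = images.shape[-2:]
    # Box sides scale with sqrt(1 - lam), so the box covers about (1 - lam) of the image.
    cut_h, cut_w = int(h * np.sqrt(1.0 - lam)), int(w * np.sqrt(1.0 - lam))
    cy, cx = int(rng.integers(h)), int(rng.integers(w))
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    mixed = images.copy()
    mixed[..., y1:y2, x1:x2] = images[perm][..., y1:y2, x1:x2]
    # Recompute lam from the box actually pasted (it may be clipped at the border).
    lam = 1.0 - ((y2 - y1) * (x2 - x1)) / (h * w)
    return mixed, labels, labels[perm], lam


def cross_entropy(probs, targets):
    """Mean negative log-likelihood of the target classes."""
    return float(-np.mean(np.log(probs[np.arange(len(targets)), targets] + 1e-12)))


images = rng.random((8, 3, 64, 64)).astype(np.float32)
labels = rng.integers(0, 6, size=8)

mixed, labels_a, labels_b, lam = mixup_sketch(images, labels)
cut, cut_a, cut_b, cut_lam = cutmix_sketch(images, labels)

# The mixed loss weights the two label sets by lam, as in the snippet above.
probs = np.full((8, 6), 1.0 / 6.0)  # stand-in for model probabilities on `mixed`
loss = lam * cross_entropy(probs, labels_a) + (1.0 - lam) * cross_entropy(probs, labels_b)
```

Recomputing λ after clipping keeps the loss weights proportional to the pixels each label actually contributes.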

Using Callbacks

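import os
import vbai

model = vbai.MultiTaskBrainModel(variant='q')

# Training and validation data, laid out as in "Dataset Structure" below
train_data = vbai.UnifiedMRIDataset(
    dementia_path='./data/dementia/train',
    tumor_path='./data/tumor/train',
    is_training=True
)
val_data = vbai.UnifiedMRIDataset(
    dementia_path='./data/dementia/val',
    tumor_path='./data/tumor/val',
    is_training=False
)

os.makedirs('checkpoints', exist_ok=True)  # make sure the checkpoint directory exists

callbacks = [
    vbai.EarlyStopping(monitor='val_loss', patience=5),
    vbai.ModelCheckpoint(
        filepath='checkpoints/model_{epoch:02d}.pt',
        monitor='val_loss',
        save_best_only=True
    )
]

trainer = vbai.Trainer(model=model, callbacks=callbacks)
trainer.fit(train_data, val_data, epochs=50)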
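`vbai.EarlyStopping` is configured above with `monitor` and `patience`, but its internals aren't described here. The sketch below shows the usual patience rule for a loss-like metric (lower is better). The class name `EarlyStoppingSketch` and its `min_delta` and `mode` options are hypothetical, not vbai's API.

```python
class EarlyStoppingSketch:
    """Stop when the monitored value fails to improve for `patience` consecutive epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0, mode: str = 'min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None
        self.wait = 0

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        if self.mode == 'min':
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value: float) -> bool:
        """Record one epoch's metric; return True if training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Hypothetical validation losses; the best value (0.65) is reached at epoch 3.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.65, 0.68, 0.70, 0.69]
stopper = EarlyStoppingSketch(patience=5)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

Note that a tie with the best value (epoch 6) does not reset the counter, so training stops after five epochs without strict improvement.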

Configuration

import vbai

# Use preset configurations
config = vbai.get_default_config('quality')  # 'default', 'fast', 'quality', 'debug'

# 3D configuration
config_3d = vbai.get_default_3d_config('quality')

# Custom config
model_config = vbai.ModelConfig(
    variant='q',
    tasks=['dementia', 'tumor'],
    dropout=0.3,
    use_edge_branch=True
)
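The feature list mentions YAML/JSON configuration support, but the file schema isn't documented here. As an illustration only, this sketch loads the four `ModelConfig` fields used above from a JSON string into a hypothetical `ModelConfigSketch` dataclass and validates them. The defaults and validation rules are assumptions; the allowed variants `f`/`q` come from the Model Variants section.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class ModelConfigSketch:
    """Hypothetical mirror of the ModelConfig fields shown above (not vbai's class)."""
    variant: str = 'q'
    tasks: List[str] = field(default_factory=lambda: ['dementia', 'tumor'])
    dropout: float = 0.3
    use_edge_branch: bool = True

    def __post_init__(self):
        # Illustrative checks; vbai's own validation may differ.
        if self.variant not in ('f', 'q'):
            raise ValueError(f"unknown variant {self.variant!r}; expected 'f' or 'q'")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")


config_json = '{"variant": "f", "tasks": ["tumor"], "dropout": 0.2, "use_edge_branch": false}'
cfg = ModelConfigSketch(**json.loads(config_json))
round_trip = json.dumps(asdict(cfg), sort_keys=True)
```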

Command Line Interface

# Train 2D model
vbai-train --dementia_path ./data/dementia --tumor_path ./data/tumor --epochs 10

# Single-task training
vbai-train --dementia_path ./data/dementia --tasks dementia --epochs 10

# Prediction
vbai-predict --model brain_model.pt --image brain_scan.jpg

# With visualization
vbai-predict --model brain_model.pt --image brain_scan.jpg --visualize --output result.png

Model Variants

2D Models

Variant      Layers  Channels        Speed   Accuracy
f (fast)     3       32-64-128       Fast    Good
q (quality)  4       64-128-256-512  Slower  Better

3D Models (NIfTI)

Variant      Stages      Channels    Speed   Accuracy
f (fast)     3x1 blocks  32-64-128   Fast    Good
q (quality)  3x2 blocks  64-128-256  Slower  Better

Dataset Structure

2D (RGB Images)

data/
├── dementia/
│   ├── train/
│   │   ├── AD_Alzheimer/
│   │   ├── AD_Mild_Demented/
│   │   ├── AD_Moderate_Demented/
│   │   ├── AD_Very_Mild_Demented/
│   │   ├── CN_Non_Demented/
│   │   └── PD_Parkinson/
│   └── val/
└── tumor/
    ├── train/
    │   ├── Glioma/
    │   ├── Meningioma/
    │   ├── No_Tumor/
    │   └── Pituitary/
    └── val/

3D (NIfTI Volumes)

data/alzheimer_3d/
├── CN/
│   ├── subject_001.nii.gz
│   └── subject_002.nii.gz
├── MCI/
│   └── subject_003.nii.gz
└── AD/
    └── subject_004.nii.gz
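A folder-per-class layout like this one is typically turned into (file, label) pairs by walking the class subfolders. The sketch below does that with `pathlib`. The helper `index_nifti_folder` is hypothetical, and the label order is borrowed from the `class_names=['CN', 'MCI', 'AD']` argument in the prediction example, so check it against how vbai actually indexes classes.

```python
import tempfile
from pathlib import Path

# Assumed label order (from the prediction example), not confirmed for NIfTIDataset.
CLASS_ORDER = ['CN', 'MCI', 'AD']
NIFTI_SUFFIXES = ('.nii', '.nii.gz')


def index_nifti_folder(root):
    """Return (path, label_index) pairs, grouped by class and sorted by filename."""
    root = Path(root)
    samples = []
    for label, name in enumerate(CLASS_ORDER):
        class_dir = root / name
        if not class_dir.is_dir():
            continue  # tolerate a missing class folder
        for path in sorted(class_dir.iterdir()):
            if path.name.endswith(NIFTI_SUFFIXES):
                samples.append((path, label))
    return samples


# Recreate the example tree above with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    layout = {'CN': ['subject_001.nii.gz', 'subject_002.nii.gz'],
              'MCI': ['subject_003.nii.gz'],
              'AD': ['subject_004.nii.gz']}
    for cls, files in layout.items():
        (Path(tmp) / cls).mkdir()
        for name in files:
            (Path(tmp) / cls / name).touch()
    samples = index_nifti_folder(tmp)
    summary = [(p.name, label) for p, label in samples]
```

Using an explicit class list rather than sorting folder names matters here: alphabetical order would put AD first and silently change every label index.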

API Reference

Core Classes

  • MultiTaskBrainModel - 2D multi-task model (dementia + tumor)
  • MultiTask3DBrainModel - 3D volumetric model (NIfTI)
  • Trainer - 2D training loop manager
  • Trainer3D - 3D training loop manager

Data

  • UnifiedMRIDataset - 2D dataset (RGB images)
  • NIfTIDataset - 3D dataset (NIfTI volumes)
  • UnifiedNIfTIDataset - 3D multi-task dataset

Augmentation

  • simulate_bias_field() / simulate_ghosting() / simulate_spike_noise() / simulate_rician_noise() - MRI artifact simulation
  • simulate_mri_artifacts() - Combined random artifact application
  • elastic_deformation_2d() / elastic_deformation_3d() - Elastic deformation
  • mixup() / cutmix() - Batch-level augmentation (2D & 3D)
  • MRIAutoAugment - Automatic augmentation policy
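Of the artifact simulators listed, Rician noise has a compact textbook definition: MRI magnitude images are the modulus of a complex signal whose real and imaginary channels each carry independent Gaussian noise. The NumPy sketch below shows that definition. The helper `add_rician_noise` is hypothetical, and `vbai.simulate_rician_noise` may scale or clip differently, so treat this as an illustration rather than its implementation.

```python
import numpy as np


def add_rician_noise(volume, std=0.03, rng=None):
    """Magnitude of (signal + Gaussian noise) + i * (Gaussian noise)."""
    rng = np.random.default_rng() if rng is None else rng
    real = volume + rng.normal(0.0, std, size=volume.shape)
    imag = rng.normal(0.0, std, size=volume.shape)
    return np.sqrt(real ** 2 + imag ** 2).astype(volume.dtype)


rng = np.random.default_rng(42)
volume = rng.random((32, 32, 32)).astype(np.float32)
noisy = add_rician_noise(volume, std=0.03, rng=rng)

# Zero-signal background becomes Rayleigh-distributed, with mean std * sqrt(pi / 2).
background = add_rician_noise(np.zeros((64, 64, 64), np.float32), std=0.03, rng=rng)
```

Unlike additive Gaussian noise, this keeps intensities non-negative and lifts the dark background above zero, which is the bias real magnitude images show.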

Hub & Export

  • list_models() / get_model_info() - Model zoo registry
  • from_hub() / push_to_hub() - HuggingFace Hub integration
  • export_onnx() - ONNX export (2D & 3D)
  • ONNXModel - PyTorch-free ONNX inference

Configuration

  • ModelConfig / Model3DConfig - Architecture settings
  • TrainingConfig / Training3DConfig - Training hyperparameters
  • get_default_config() / get_default_3d_config() - Presets

Callbacks

  • EarlyStopping - Stop when no improvement
  • ModelCheckpoint - Save best/all checkpoints
  • TensorBoardLogger - Log to TensorBoard

Examples

See the examples/ directory:

  • train_basic.py - Basic 2D training
  • train_advanced.py - Advanced training with callbacks
  • train_3d.py - 3D NIfTI training
  • inference.py - Model inference

Citation

@software{vbai,
  title = {Vbai: Visual Brain AI Library},
  author = {Neurazum},
  year = {2025},
  url = {https://github.com/Neurazum-AI-Department/vbai}
}

License

MIT License - see LICENSE for details.

Contributing

Contribution guidelines are planned but not yet available.

Support


Neurazum AI Department

Download files

Source Distribution

vbai-0.2.7.tar.gz (66.8 kB)

Built Distribution

vbai-0.2.7-py3-none-any.whl (72.0 kB)

File details

Details for the file vbai-0.2.7.tar.gz.

File metadata

  • Download URL: vbai-0.2.7.tar.gz
  • Upload date:
  • Size: 66.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.11

File hashes

Hashes for vbai-0.2.7.tar.gz
Algorithm Hash digest
SHA256 996b056fe2ef23ea90faa6595dfb24389582637b67be5dd6b4118f0740485a81
MD5 710a83b9fcbe0d4edaa0fa6beee637cd
BLAKE2b-256 fb17fe7542460ebe02a2131d34760228fa29de4e8b3179339bb45b8db38dcce2


File details

Details for the file vbai-0.2.7-py3-none-any.whl.

File metadata

  • Download URL: vbai-0.2.7-py3-none-any.whl
  • Upload date:
  • Size: 72.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.11

File hashes

Hashes for vbai-0.2.7-py3-none-any.whl
Algorithm Hash digest
SHA256 d8e82c11558579d460e463539404f590dbbcf9702d3b940b56035d5e3d1f6c77
MD5 5585b97482e88ab3bbf8bf518aa943de
BLAKE2b-256 094f38332e6fe16910b2e0b4d6662cb39ce08d29b02bab0a4c09b17b4e94d761
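The digests above let you check a downloaded file before installing it. Below is a minimal sketch using Python's standard `hashlib`; the helper names are illustrative, and the self-check hashes a throwaway file containing `b"abc"` (whose SHA-256 digest is well known) rather than the real wheel.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, 'rb') as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected):
    """Compare a file's SHA-256 hex digest with a published value."""
    return sha256_of(path) == expected.lower()


# Published SHA-256 digest for vbai-0.2.7-py3-none-any.whl (from the table above).
EXPECTED_WHEEL_SHA256 = 'd8e82c11558579d460e463539404f590dbbcf9702d3b940b56035d5e3d1f6c77'

# Self-check on a throwaway file with a well-known digest (SHA-256 of b"abc").
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / 'sample.bin'
    sample.write_bytes(b'abc')
    ok = verify(sample, 'ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad')
```

After downloading the wheel, `verify('vbai-0.2.7-py3-none-any.whl', EXPECTED_WHEEL_SHA256)` should return `True`. pip can also enforce this itself in hash-checking mode (`pip install --require-hashes -r requirements.txt` with a line such as `vbai==0.2.7 --hash=sha256:<digest>`), though that mode requires every dependency to be pinned with a hash.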

