
Visual Brain AI - Multi-task brain MRI analysis library


Vbai - Visual Brain AI

Requires Python 3.8+ and PyTorch. License: MIT.

A PyTorch-based deep learning library for multi-task brain MRI analysis. Train models for dementia classification, brain tumor detection, or both simultaneously with just a few lines of code.

Features

  • Flexible Task Selection: Train for dementia only, tumor only, or both simultaneously
  • Easy to Use: Keras-like API for quick training
  • Task-Specific Attention: Separate attention mechanisms for each task
  • Visualization: Built-in attention heatmap visualization (adapts to active tasks)
  • Configurable: YAML/JSON configuration support
  • Production Ready: Export and deploy trained models

Supported Classifications

Dementia (6 classes):

  • AD Alzheimer's Disease
  • AD Mild Demented
  • AD Moderate Demented
  • AD Very Mild Demented
  • CN Non-Demented (Cognitively Normal)
  • PD Parkinson's Disease

Brain Tumor (4 classes):

  • Glioma
  • Meningioma
  • No Tumor
  • Pituitary

Installation

# Basic installation
pip install vbai

# With all optional dependencies (quotes stop shells such as zsh from expanding the brackets)
pip install "vbai[full]"

# Development installation
git clone https://github.com/Neurazum-AI-Department/vbai.git
cd vbai
pip install -e ".[dev]"

Quick Start

Training a Multi-Task Model (Both Dementia & Tumor)

import vbai

# Create model for both tasks (default)
model = vbai.MultiTaskBrainModel(variant='q')  # 'q' for quality, 'f' for fast

# Prepare dataset
dataset = vbai.UnifiedMRIDataset(
    dementia_path='./data/dementia/train',
    tumor_path='./data/tumor/train',
    is_training=True
)

# Create trainer
trainer = vbai.Trainer(
    model=model,
    lr=0.0005,
    device='cuda'
)

# Train
history = trainer.fit(
    train_data=dataset,
    epochs=10,
    batch_size=32
)

# Save model
trainer.save('brain_model.pt')
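
When both tasks are trained together, an individual image usually carries a label for only one task (a dementia scan has no tumor label, and vice versa), which is presumably what "handles missing task data" refers to for UnifiedMRIDataset. A common way to combine the two heads is to sum per-task cross-entropy losses and skip any task whose label is missing. The plain-Python sketch below illustrates that idea; the function names and the optional per-task weights are hypothetical, and vbai's actual loss is not documented here.

```python
import math
from typing import Dict, List, Optional


def cross_entropy(logits: List[float], target: int) -> float:
    """Negative log-softmax of the target class (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def multitask_loss(
    outputs: Dict[str, List[float]],
    labels: Dict[str, Optional[int]],
    weights: Optional[Dict[str, float]] = None,
) -> float:
    """Sum per-task losses, skipping tasks whose label is missing (None)."""
    weights = weights or {}
    total = 0.0
    for task, logits in outputs.items():
        target = labels.get(task)
        if target is None:
            # e.g. a dementia scan has no tumor label: the tumor head contributes nothing
            continue
        total += weights.get(task, 1.0) * cross_entropy(logits, target)
    return total


# Hypothetical logits: 6 dementia classes, 4 tumor classes
outputs = {"dementia": [2.0, 0.1, 0.1, 0.1, 0.5, 0.1], "tumor": [0.3, 0.2, 1.5, 0.1]}
loss = multitask_loss(outputs, {"dementia": 0, "tumor": None})  # tumor head skipped
```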

Training a Single-Task Model

import vbai

# Dementia only model
dementia_model = vbai.MultiTaskBrainModel(
    variant='q',
    tasks=['dementia']  # Only dementia classification
)

# Tumor only model
tumor_model = vbai.MultiTaskBrainModel(
    variant='q',
    tasks=['tumor']  # Only tumor detection
)

# Dataset for single task
dementia_dataset = vbai.UnifiedMRIDataset(
    dementia_path='./data/dementia/train',
    tumor_path=None,  # Not needed for dementia-only
    is_training=True
)

# Train as usual
trainer = vbai.Trainer(model=dementia_model, lr=0.0005)
trainer.fit(train_data=dementia_dataset, epochs=10)

Making Predictions

import vbai

# Load trained model
model = vbai.load('brain_model.pt', device='cuda')

# Single image prediction
result = model.predict('brain_scan.jpg')

# Multi-task model returns both predictions
if result.dementia_class:
    print(f"Dementia: {result.dementia_class} ({result.dementia_confidence:.1%})")
if result.tumor_class:
    print(f"Tumor: {result.tumor_class} ({result.tumor_confidence:.1%})")

# With attention visualization
result = model.predict('brain_scan.jpg', return_attention=True)
vis = vbai.VisualizationManager()
vis.visualize('brain_scan.jpg', result, save=True)
# Visualization adapts to show only active task panels
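
The `:.1%` format above implies that `dementia_confidence` and `tumor_confidence` are fractions between 0 and 1. A typical way to produce such a score is to take the softmax probability of the highest-scoring class. The sketch below assumes exactly that (vbai's real post-processing is not documented here) and also assumes the class index order matches the list under Supported Classifications.

```python
import math
from typing import List, Tuple

# Class names from "Supported Classifications"; the index order is an assumption.
TUMOR_CLASSES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]


def softmax(logits: List[float]) -> List[float]:
    """Convert raw scores to probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits: List[float], classes: List[str]) -> Tuple[str, float]:
    """Return the most probable class name and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return classes[best], probs[best]


label, confidence = top_prediction([0.0, 0.0, 0.0, math.log(7.0)], TUMOR_CLASSES)
print(f"Tumor: {label} ({confidence:.1%})")  # prints: Tumor: Pituitary (70.0%)
```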

Using Callbacks

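import vbai

model = vbai.MultiTaskBrainModel(variant='q')

# Training and validation data (both callbacks below monitor val_loss)
train_data = vbai.UnifiedMRIDataset(
    dementia_path='./data/dementia/train',
    tumor_path='./data/tumor/train',
    is_training=True
)
val_data = vbai.UnifiedMRIDataset(
    dementia_path='./data/dementia/val',
    tumor_path='./data/tumor/val',
    is_training=False
)

# Set up callbacks
callbacks = [
    vbai.EarlyStopping(monitor='val_loss', patience=5),
    vbai.ModelCheckpoint(
        filepath='checkpoints/model_{epoch:02d}.pt',
        monitor='val_loss',
        save_best_only=True
    )
]

trainer = vbai.Trainer(model=model, callbacks=callbacks)
trainer.fit(train_data, val_data, epochs=50)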
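
EarlyStopping(patience=5) and ModelCheckpoint(save_best_only=True) conventionally follow Keras-style rules. Training stops after `patience` consecutive epochs without a lower val_loss, and a checkpoint is written only when val_loss improves. The plain-Python simulation below illustrates those rules on a made-up loss curve. It is not vbai's implementation, and details such as a minimum-improvement threshold or whether `{epoch}` is 0- or 1-based (1-based here) are assumptions.

```python
from typing import List, Tuple


def simulate_callbacks(
    val_losses: List[float],
    patience: int = 5,
    filepath: str = "checkpoints/model_{epoch:02d}.pt",
) -> Tuple[int, List[str]]:
    """Return (epochs actually run, checkpoint paths that would be written)."""
    best = float("inf")
    epochs_without_improvement = 0
    saved: List[str] = []
    epochs_run = 0
    for epoch, loss in enumerate(val_losses, start=1):
        epochs_run = epoch
        if loss < best:
            best = loss
            epochs_without_improvement = 0
            # save_best_only=True: write a checkpoint only when val_loss improves
            saved.append(filepath.format(epoch=epoch))
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping: no improvement for `patience` epochs
    return epochs_run, saved


losses = [1.0, 0.8, 0.9, 0.85, 0.7, 0.75, 0.76, 0.74, 0.8, 0.9, 0.5]
print(simulate_callbacks(losses, patience=5))
# (10, ['checkpoints/model_01.pt', 'checkpoints/model_02.pt', 'checkpoints/model_05.pt'])
```

In this run the final 0.5 is never seen: training has already stopped at epoch 10.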

Configuration

import vbai

# Use preset configurations
config = vbai.get_default_config('quality')  # 'default', 'fast', 'quality', 'debug'

# Or customize
model_config = vbai.ModelConfig(
    variant='q',
    tasks=['dementia', 'tumor'],  # Choose which tasks to enable
    dropout=0.3,
    use_edge_branch=True
)

training_config = vbai.TrainingConfig(
    epochs=20,
    batch_size=16,
    lr=0.0001,
    scheduler='cosine'
)

# Save/Load configs
model_config.save('model_config.yaml')
loaded_config = vbai.ModelConfig.load('model_config.yaml')
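
YAML/JSON support suggests ModelConfig behaves like a plain settings object that can be serialized to a file and read back. The sketch below shows that round-trip pattern using a hypothetical `ModelConfigSketch` dataclass with the fields from the example, stored as JSON via the standard library. The file layout vbai actually writes is not documented here, so treat this as an illustration rather than the library's format.

```python
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import List


@dataclass
class ModelConfigSketch:
    """Hypothetical stand-in for vbai.ModelConfig, using the fields shown above."""
    variant: str = "q"
    tasks: List[str] = field(default_factory=lambda: ["dementia", "tumor"])
    dropout: float = 0.3
    use_edge_branch: bool = True

    def save(self, path: str) -> None:
        # Serialize every field to a human-readable JSON file
        Path(path).write_text(json.dumps(asdict(self), indent=2))

    @classmethod
    def load(cls, path: str) -> "ModelConfigSketch":
        # Rebuild the dataclass from the saved key/value pairs
        return cls(**json.loads(Path(path).read_text()))
```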

Command Line Interface

Training

# Train both tasks
vbai-train --dementia_path ./data/dementia --tumor_path ./data/tumor --epochs 10

# Train dementia only
vbai-train --dementia_path ./data/dementia --tasks dementia --epochs 10

# Train tumor only
vbai-train --tumor_path ./data/tumor --tasks tumor --epochs 10

# Advanced options
vbai-train --dementia_path ./data/dementia --tumor_path ./data/tumor \
    --variant q --tasks dementia tumor --epochs 20 --batch_size 16 --lr 0.0001

Prediction

# Make prediction
vbai-predict --model brain_model.pt --image brain_scan.jpg

# With visualization
vbai-predict --model brain_model.pt --image brain_scan.jpg --visualize --output result.png

# JSON output
vbai-predict --model brain_model.pt --image brain_scan.jpg --json

Model Variants

Variant       Layers   Channels          Speed    Accuracy
f (fast)      3        32-64-128         Fast     Good
q (quality)   4        64-128-256-512    Slower   Better
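
Most of the speed difference follows from the channel widths. A k x k convolution from C_in to C_out channels has C_in * C_out * k * k weights plus C_out biases. The arithmetic below applies that formula to both channel progressions, assuming one 3x3 convolution per listed width and 3 input channels. vbai's real blocks (attention, edge branch, classifier heads) differ, so the numbers only illustrate the relative cost, not the actual model sizes.

```python
from typing import List


def conv_stack_params(channels: List[int], in_channels: int = 3, kernel: int = 3) -> int:
    """Weights + biases of a plain stack of kernel x kernel conv layers."""
    total = 0
    prev = in_channels
    for out in channels:
        total += prev * out * kernel * kernel + out  # weights + one bias per output channel
        prev = out
    return total


fast = conv_stack_params([32, 64, 128])
quality = conv_stack_params([64, 128, 256, 512])
print(fast, quality)  # 93248 1550976
```

Under these assumptions the 'q' stack has roughly 16 times as many convolution parameters as the 'f' stack, most of them in its final 256-to-512 layer.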

Task Selection

tasks parameter          Description            Use case
['dementia', 'tumor']    Both tasks (default)   Multi-task learning
['dementia']             Dementia only          Specialized dementia detection
['tumor']                Tumor only             Specialized tumor detection
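
Each entry in `tasks` enables one classification head, sized by the class counts listed under Supported Classifications (6 for dementia, 4 for tumor). The hypothetical helper below sketches the kind of validation a `tasks` argument needs. It is not part of vbai's API, and the library's own checks and error messages may differ.

```python
from typing import Dict, Iterable

# Class counts from "Supported Classifications".
NUM_CLASSES = {"dementia": 6, "tumor": 4}


def heads_for_tasks(tasks: Iterable[str]) -> Dict[str, int]:
    """Validate a tasks list and return {task: number of output classes}."""
    tasks = list(tasks)
    if not tasks:
        raise ValueError("tasks must contain at least one of: dementia, tumor")
    unknown = [t for t in tasks if t not in NUM_CLASSES]
    if unknown:
        raise ValueError(f"unknown task(s): {unknown}")
    return {t: NUM_CLASSES[t] for t in tasks}


print(heads_for_tasks(["dementia", "tumor"]))  # {'dementia': 6, 'tumor': 4}
print(heads_for_tasks(["tumor"]))              # {'tumor': 4}
```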

Dataset Structure

Your dataset should be organized as follows:

data/
├── dementia/
│   ├── train/
│   │   ├── AD_Alzheimer/
│   │   ├── AD_Mild_Demented/
│   │   ├── AD_Moderate_Demented/
│   │   ├── AD_Very_Mild_Demented/
│   │   ├── CN_Non_Demented/
│   │   └── PD_Parkinson/
│   └── val/
│       └── ...
└── tumor/
    ├── train/
    │   ├── Glioma/
    │   ├── Meningioma/
    │   ├── No_Tumor/
    │   └── Pituitary/
    └── val/
        └── ...

Note: You only need the dataset for the task(s) you're training. For single-task training, only the relevant dataset directory is required.
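
Before training, it can help to confirm that the class folders match this layout. The standard-library sketch below lists any expected folder that is missing under `<root>/<task>/<split>/`. The function name is hypothetical and not part of vbai, and the sketch assumes the folder names shown in the tree are exactly what vbai expects; the library itself may tolerate missing or extra folders.

```python
from pathlib import Path
from typing import Dict, List

# Folder names taken from the directory tree above.
EXPECTED_CLASSES: Dict[str, List[str]] = {
    "dementia": ["AD_Alzheimer", "AD_Mild_Demented", "AD_Moderate_Demented",
                 "AD_Very_Mild_Demented", "CN_Non_Demented", "PD_Parkinson"],
    "tumor": ["Glioma", "Meningioma", "No_Tumor", "Pituitary"],
}


def missing_class_dirs(root: str, tasks: List[str], split: str = "train") -> List[str]:
    """List expected class folders that do not exist under root/<task>/<split>/."""
    missing = []
    for task in tasks:
        for name in EXPECTED_CLASSES[task]:
            path = Path(root) / task / split / name
            if not path.is_dir():
                missing.append(str(path))
    return missing


# Check only the task(s) you plan to train
problems = missing_class_dirs("data", ["dementia", "tumor"])
```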

API Reference

Core Classes

  • MultiTaskBrainModel - Main model class (supports single and multi-task)
  • UnifiedMRIDataset - Dataset for training (handles missing task data)
  • Trainer - Training loop manager
  • VisualizationManager - Attention heatmap visualization (adapts to active tasks)

Configuration

  • ModelConfig - Model architecture settings (includes tasks parameter)
  • TrainingConfig - Training hyperparameters
  • get_default_config() - Preset configurations

Callbacks

  • EarlyStopping - Stop when no improvement
  • ModelCheckpoint - Save best/all checkpoints
  • TensorBoardLogger - Log to TensorBoard

Examples

See the examples/ directory for complete examples:

  • train_basic.py - Basic training example
  • train_advanced.py - Advanced training with callbacks
  • inference.py - Model inference

Citation

If you use Vbai in your research, please cite:

@software{vbai,
  title = {Vbai: Visual Brain AI Library},
  author = {Neurazum},
  year = {2025},
  url = {https://github.com/Neurazum-AI-Department/vbai}
}

License

MIT License - see LICENSE for details.

Contributing

Contribution guidelines are being planned.

Support


Neurazum AI Department

Download files

Source Distribution

vbai-0.1.9.tar.gz (38.3 kB)

Built Distribution

vbai-0.1.9-py3-none-any.whl (38.0 kB)

File details

Details for the file vbai-0.1.9.tar.gz.

File metadata

  • Download URL: vbai-0.1.9.tar.gz
  • Size: 38.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.11

File hashes

Hashes for vbai-0.1.9.tar.gz
Algorithm Hash digest
SHA256 639a0999643d5bf74163a5032f8ea602443198035c26dcfabcd23f56a655bbde
MD5 17f30ffa8e07c6a772c495d5b459df83
BLAKE2b-256 bc1f96254c3a75381e8ff701e057600e1af5c49f0d3e369c90e44b877ee2ce05


File details

Details for the file vbai-0.1.9-py3-none-any.whl.

File metadata

  • Download URL: vbai-0.1.9-py3-none-any.whl
  • Size: 38.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.11

File hashes

Hashes for vbai-0.1.9-py3-none-any.whl
Algorithm Hash digest
SHA256 ad4e1246f3e3ff748f81897eadee75d4aa4b1f95325c284b7c609bc8f98063fe
MD5 c1b21bfadf53d94ae42cf7a27c4ac389
BLAKE2b-256 f53f484b6e96cfe056970685e425465e8afb0c4054ab5605333b249895aec6ba
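
To check a downloaded file against the SHA256 digests above, Python's `hashlib` is enough. The minimal sketch below uses illustrative function names. As an alternative, pip can enforce hashes itself in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).

```python
import hashlib
from pathlib import Path

# SHA256 digests published above for vbai 0.1.9.
EXPECTED_SHA256 = {
    "vbai-0.1.9.tar.gz": "639a0999643d5bf74163a5032f8ea602443198035c26dcfabcd23f56a655bbde",
    "vbai-0.1.9-py3-none-any.whl": "ad4e1246f3e3ff748f81897eadee75d4aa4b1f95325c284b7c609bc8f98063fe",
}


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str) -> bool:
    """True if the file's SHA256 matches the published digest for its file name."""
    return sha256_of(path) == EXPECTED_SHA256[Path(path).name]
```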

