FOD-Net Reimplementation.
FODNet
FOD-Net reimplementation with training and inference pipelines. This module uses the FODNet model from the authors' original implementation.
If you use this code for your research, please cite:
FOD-Net: A Deep Learning Method for Fiber Orientation Distribution Angular Super Resolution.
Rui Zeng, Jinglei Lv, He Wang, Luping Zhou, Michael Barnett, Fernando Calamante*, Chenyu Wang*. In Medical Image Analysis. (* equal contributions) [Bibtex].
Requirements
This module requires the following Python packages:
torch >= 2.0.0
lightning >= 2.0.0
numpy
einops
npy-patcher
scikit-image
nibabel
These are installed automatically with this package. However, it is recommended to install PyTorch independently first, following the official PyTorch installation instructions, so that the correct hardware optimizations are enabled.
Installation
pip install fodnet
Training
Follow the instructions below on how to train the FODNet model.
Data Preprocessing
This training pipeline requires data to be saved in .npy format. Additionally, the spherical harmonic (SH) dimension must be the first dimension of each 4D array, because this module uses npy-patcher to extract training patches at runtime. Below is an example of how to convert NIfTI files into .npy using nibabel.
import numpy as np
import nibabel as nib
img = nib.load('/path/to/fod.nii.gz')
data = np.asarray(img.dataobj, dtype=np.float32) # Load FOD data into memory
data = data.transpose(3, 0, 1, 2) # Move the SH dimension to 0
np.save('/path/to/fod.npy', data, allow_pickle=False) # Save in npy format. Ensure this is on an SSD.
N.B. Patches are read lazily from disk, so it is highly recommended to store the training data on an SSD type device; an HDD will bottleneck data loading during training.
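Before saving, it can help to confirm that the last axis of the loaded NIfTI array really holds the SH coefficients. The sketch below assumes a real, symmetric (even-order) SH basis, for which lmax = 8 gives (8+1)(8+2)/2 = 45 coefficients. The helpers n_sh_coeffs and to_sh_first are hypothetical (not part of fodnet), and the default lmax=8 is an assumption to adjust to your data.

```python
import numpy as np


def n_sh_coeffs(lmax: int) -> int:
    """Number of coefficients in a real, even-order (symmetric) SH basis."""
    if lmax < 0 or lmax % 2:
        raise ValueError("lmax must be a non-negative even integer")
    return (lmax + 1) * (lmax + 2) // 2


def to_sh_first(data: np.ndarray, lmax: int = 8) -> np.ndarray:
    """Validate a 4D (X, Y, Z, SH) array and move its SH axis to axis 0.

    Equivalent to data.transpose(3, 0, 1, 2), plus shape checks.
    Returns a C-contiguous float32 copy.
    """
    if data.ndim != 4:
        raise ValueError(f"expected a 4D array, got shape {data.shape}")
    expected = n_sh_coeffs(lmax)
    if data.shape[-1] != expected:
        raise ValueError(
            f"last axis has {data.shape[-1]} coefficients, "
            f"expected {expected} for lmax={lmax}"
        )
    return np.ascontiguousarray(np.moveaxis(data, -1, 0), dtype=np.float32)
```

With the nibabel example, this would replace the manual transpose: data = to_sh_first(np.asarray(img.dataobj)), followed by the same np.save call.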
Training
import lightning.pytorch as pl
from fodnet.core.model import FODNetLightningModel
from fodnet.core.dataset import Subject, FODNetDataModule
# Collect dataset filepaths
subj1 = Subject('/path/to/lowres_fod1.npy', '/path/to/highres_fod1.npy', '/path/to/mask1.npy')
subj2 = Subject('/path/to/lowres_fod2.npy', '/path/to/highres_fod2.npy', '/path/to/mask2.npy')
subj3 = Subject('/path/to/lowres_fod3.npy', '/path/to/highres_fod3.npy', '/path/to/mask3.npy')
# Create DataModule instance. This is a thin wrapper around `pl.LightningDataModule`.
data_module = FODNetDataModule(
train_subjects=(subj1, subj2),
val_subjects=(subj3,), # The trailing comma makes this a one-element tuple
batch_size=16, # Batch size of each device
num_workers=8, # Number of CPU workers that load the data
)
# Load FODNet lightning model
model = FODNetLightningModel()
# Create `pl.Trainer` instance. `FODNetDataModule` can be used with the DDP distributed training strategy.
trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=100)
# Start training
trainer.fit(model, data_module)
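With many subjects, listing each Subject by hand gets tedious. The sketch below assumes a hypothetical directory layout, <root>/<subject_id>/ containing lowres_fod.npy, highres_fod.npy and mask.npy, and performs a seeded random train/validation split. The function collect_subject_paths and the file names are illustrative assumptions, not part of fodnet.

```python
import random
from pathlib import Path

# Hypothetical file names inside each subject directory (an assumption, not a fodnet convention).
FILE_NAMES = ("lowres_fod.npy", "highres_fod.npy", "mask.npy")


def collect_subject_paths(root, val_fraction=0.2, seed=0):
    """Return (train, val) lists of (lowres, highres, mask) path triples.

    Subject directories missing any of the three files are skipped.
    The split uses a seeded shuffle, so it is reproducible across runs.
    """
    triples = []
    for subj_dir in sorted(Path(root).iterdir()):
        if not subj_dir.is_dir():
            continue
        files = [subj_dir / name for name in FILE_NAMES]
        if all(f.is_file() for f in files):
            triples.append(tuple(str(f) for f in files))

    random.Random(seed).shuffle(triples)
    # Keep at least one validation subject whenever any subjects were found.
    n_val = max(1, round(len(triples) * val_fraction)) if triples else 0
    return triples[n_val:], triples[:n_val]
```

Each triple can then be unpacked into Subject(*triple), and the resulting tuples passed as train_subjects and val_subjects to FODNetDataModule.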
Customization
This implementation uses a different optimizer, loss, and learning rate from the original implementation. In particular, we use AdamW, an L1 loss, and a learning rate of 0.003.
Changing these hyperparameters is straightforward: create a new class that inherits from FODNetLightningModel, override the properties/methods below, and use this class instead of FODNetLightningModel when training.
import torch
from fodnet.core.model import FODNetLightningModel

class MyCustomModel(FODNetLightningModel):
    @property
    def loss_func(self):
        '''Different loss function'''
        return torch.nn.functional.mse_loss

    def configure_optimizers(self):
        '''Different optimizer and learning rate'''
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-5)
        return optimizer
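To get a feel for what swapping L1 for MSE changes, the NumPy-only sketch below computes both losses with mean reduction (the default reduction of torch.nn.functional.l1_loss and torch.nn.functional.mse_loss) on residuals that contain one outlier. It is illustrative only and does not call fodnet; the helper names are hypothetical.

```python
import numpy as np


def l1_loss(pred, target):
    """Mean absolute error (mean-reduced L1 loss)."""
    return float(np.mean(np.abs(pred - target)))


def mse_loss(pred, target):
    """Mean squared error (mean-reduced MSE loss)."""
    return float(np.mean((pred - target) ** 2))


target = np.zeros(4)
pred = np.array([0.1, -0.1, 0.1, 2.0])  # three small errors and one outlier

l1 = l1_loss(pred, target)    # (0.1 + 0.1 + 0.1 + 2.0) / 4 = 0.575
mse = mse_loss(pred, target)  # (3 * 0.01 + 4.0) / 4 = 1.0075

# Fraction of each loss contributed by the outlier alone
print(f"L1:  {l1:.4f}, outlier share {2.0 / (4 * l1):.0%}")
print(f"MSE: {mse:.4f}, outlier share {4.0 / (4 * mse):.0%}")
```

Squaring lets the single large error dominate the MSE far more than the L1 loss. Robustness to outliers is one common reason to prefer L1, though the motivation for the defaults here is not stated.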
Prediction
from fodnet.core.model import FODNetLightningModel
from fodnet.core.prediction import FODNetPredictionProcessor
model = FODNetLightningModel.load_from_checkpoint('/path/to/my/checkpoint.ckpt')
predict = FODNetPredictionProcessor(batch_size=32, num_workers=8, accelerator='gpu')
predict.run_subject(
model,
'/path/to/my/brainmask.nii.gz',
'/path/to/lowres_fod.nii.gz',
'/path/to/dest/highres_fod.nii.gz',
tmp_dir=None, # Optionally specify a temporary directory to save the FOD file during processing
)
N.B. Patches are read lazily from disk, so it is recommended to ensure tmp_dir is on an SSD type device.
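To compare a predicted FOD against a reference high-resolution FOD, a common measure is the voxel-wise angular correlation coefficient (ACC) between SH coefficient vectors, excluding the l=0 term. The sketch below assumes both volumes have already been loaded (e.g. with nibabel) and transposed to the SH-first layout (n_sh, X, Y, Z) used for training. The function angular_correlation is hypothetical and not part of fodnet, and whether it matches the evaluation protocol of the FOD-Net paper is not stated here.

```python
import numpy as np


def angular_correlation(sh_a, sh_b, mask=None, eps=1e-12):
    """Voxel-wise ACC between two SH-first FOD arrays of shape (n_sh, X, Y, Z).

    The l=0 coefficient (index 0) is excluded, so only the angular
    structure is compared. Voxels outside `mask` are set to NaN.
    """
    a = np.asarray(sh_a, dtype=np.float64)[1:]
    b = np.asarray(sh_b, dtype=np.float64)[1:]
    num = np.sum(a * b, axis=0)
    den = np.sqrt(np.sum(a * a, axis=0) * np.sum(b * b, axis=0))
    # eps avoids division by zero in empty (all-zero) voxels, which yield 0
    acc = num / np.maximum(den, eps)
    if mask is not None:
        acc = np.where(np.asarray(mask, dtype=bool), acc, np.nan)
    return acc
```

A single summary value over the brain could then be obtained with np.nanmean(angular_correlation(pred, ref, mask)).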