
FOD-Net Reimplementation.


FODNet

FOD-Net reimplementation with a training and inference pipeline. This module uses the FODNet model from the original FOD-Net implementation.

If you use this code for your research, please cite:

FOD-Net: A Deep Learning Method for Fiber Orientation Distribution Angular Super Resolution.
Rui Zeng, Jinglei Lv, He Wang, Luping Zhou, Michael Barnett, Fernando Calamante*, Chenyu Wang*. In Medical Image Analysis. (* equal contributions)

Requirements

This module requires the following Python packages:

  • torch >= 2.0.0
  • lightning >= 2.0.0
  • numpy
  • einops
  • npy-patcher
  • scikit-image
  • nibabel

These are installed automatically along with this package. However, it is recommended to install PyTorch independently first, following the official PyTorch installation instructions, so that the correct hardware optimizations (e.g. the matching CUDA build) are enabled.

Installation

pip install fodnet

Training

Follow the instructions below on how to train the FODNet model.

Data Preprocessing

This training pipeline requires data to be saved in .npy format. Additionally, the spherical harmonic (SH) dimension must be the first dimension of each 4D array, because this module uses npy-patcher to extract training patches at runtime. Below is an example of how to convert NIfTI files into .npy using nibabel.

import numpy as np
import nibabel as nib

img = nib.load('/path/to/fod.nii.gz')
data = np.asarray(img.dataobj, dtype=np.float32)  # Load FOD data into memory
data = data.transpose(3, 0, 1, 2)  # Move the SH dimension to 0
np.save('/path/to/fod.npy', data, allow_pickle=False)  # Save in npy format. Ensure this is on an SSD.
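Because the pipeline expects the SH coefficients along axis 0, a quick sanity check before saving can catch many layout mistakes. The sketch below assumes the FOD image stores an even-order real SH basis (as MRtrix3 does), so a valid SH axis has (l_max+1)(l_max+2)/2 entries (e.g. 45 for l_max = 8). The helper names `even_sh_order`, `to_sh_first` and `to_sh_last` are hypothetical and not part of fodnet.

```python
import numpy as np


def even_sh_order(n_coeffs: int) -> int:
    """Return l_max for an even-order real SH basis with n_coeffs terms.

    Hypothetical helper: raises ValueError if n_coeffs is not of the form
    (l_max + 1) * (l_max + 2) / 2 for some even l_max.
    """
    lmax = 0
    while (lmax + 1) * (lmax + 2) // 2 < n_coeffs:
        lmax += 2
    if (lmax + 1) * (lmax + 2) // 2 != n_coeffs:
        raise ValueError(f"{n_coeffs} is not a valid even-order SH coefficient count")
    return lmax


def to_sh_first(data: np.ndarray) -> np.ndarray:
    """Move the SH axis of a 4D (X, Y, Z, SH) array to axis 0 and validate it."""
    if data.ndim != 4:
        raise ValueError(f"expected a 4D array, got shape {data.shape}")
    even_sh_order(data.shape[3])  # fail early if the last axis is not an SH axis
    # Return a C-contiguous float32 copy in (SH, X, Y, Z) order
    return np.ascontiguousarray(data.transpose(3, 0, 1, 2), dtype=np.float32)


def to_sh_last(data: np.ndarray) -> np.ndarray:
    """Inverse of to_sh_first: (SH, X, Y, Z) -> (X, Y, Z, SH)."""
    return data.transpose(1, 2, 3, 0)
```

In the conversion above, `data = to_sh_first(np.asarray(img.dataobj))` would replace the plain transpose. The check cannot catch every mistake (a spatial axis of size 6 or 15 would also pass), so it is a guard rather than a guarantee.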

N.B. Patches are read lazily from disk, so it is highly recommended to store the training data on an SSD; an HDD will bottleneck data loading during training.
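To illustrate why storage speed matters: a memory-mapped .npy file only pulls in the disk pages that back the requested slice. The sketch below uses NumPy's `np.load(..., mmap_mode='r')` as a stand-in for lazy patch reading; it is not npy-patcher's API, and the helper name `read_patch` and the cubic patch shape are illustrative assumptions.

```python
import os
import tempfile

import numpy as np


def read_patch(path, corner, size):
    """Read a (SH, size, size, size) patch from an SH-first .npy file.

    Hypothetical helper. np.load with mmap_mode='r' maps the file instead of
    loading it, so only the pages covering the sliced region are read.
    """
    arr = np.load(path, mmap_mode='r')
    x, y, z = corner
    patch = arr[:, x:x + size, y:y + size, z:z + size]
    return np.array(patch)  # copy the slice into regular memory
```

Each patch touches many small, scattered regions of the file (one run per SH channel and spatial row), which is the access pattern that seek-bound HDDs handle poorly.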

Model Training

import lightning.pytorch as pl

from fodnet.core.model import FODNetLightningModel
from fodnet.core.dataset import Subject, FODNetDataModule

# Collect dataset filepaths
subj1 = Subject('/path/to/lowres_fod1.npy', '/path/to/highres_fod1.npy', '/path/to/mask1.npy')
subj2 = Subject('/path/to/lowres_fod2.npy', '/path/to/highres_fod2.npy', '/path/to/mask2.npy')
subj3 = Subject('/path/to/lowres_fod3.npy', '/path/to/highres_fod3.npy', '/path/to/mask3.npy')

# Create DataModule instance. This is a thin wrapper around `pl.LightningDataModule`.
data_module = FODNetDataModule(
    train_subjects=(subj1, subj2),
    val_subjects=(subj3,),  # Trailing comma makes this a one-element tuple
    batch_size=16,  # Batch size per device
    num_workers=8,  # Number of CPU workers that load the data
)

# Load FODNet lightning model
model = FODNetLightningModel()

# Create `pl.Trainer` instance. `FODNetDataModule` also supports the DDP distributed training strategy.
trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=100)

# Start training
trainer.fit(model, data_module)
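With more than a handful of subjects, the train/validation tuples can be built programmatically. The sketch below assumes a hypothetical directory layout (`<root>/<subject_id>/lowres_fod.npy`, `highres_fod.npy`, `mask.npy`) and only produces path triples in the same order as the `Subject(...)` calls above; `split_subject_paths` is not part of fodnet. Both splits are returned as tuples, so a single validation subject still yields a one-element tuple.

```python
import os
import random


def split_subject_paths(root, subject_ids, n_val=1, seed=0):
    """Return (train, val) tuples of (lowres, highres, mask) path triples.

    The file names below are a hypothetical layout, not something fodnet requires.
    """
    triples = [
        (
            os.path.join(root, sid, 'lowres_fod.npy'),
            os.path.join(root, sid, 'highres_fod.npy'),
            os.path.join(root, sid, 'mask.npy'),
        )
        for sid in subject_ids
    ]
    rng = random.Random(seed)  # fixed seed gives a reproducible split
    rng.shuffle(triples)
    return tuple(triples[n_val:]), tuple(triples[:n_val])
```

Each triple can then be unpacked into `Subject`, e.g. `train_subjects=tuple(Subject(*t) for t in train)` and likewise for `val_subjects`.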

Customization

This implementation uses a different optimizer, loss function, and learning rate than the original implementation. Specifically, we use AdamW, L1 loss, and a learning rate of 0.003, respectively.
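For reference, with PyTorch's default `reduction='mean'`, `torch.nn.functional.l1_loss` computes the mean absolute error and `torch.nn.functional.mse_loss` the mean squared error. The NumPy sketch below is an illustration only, not code from this package; it shows how the two losses weigh a single large residual differently.

```python
import numpy as np


def l1_loss(pred, target):
    """Mean absolute error, as F.l1_loss computes with reduction='mean'."""
    return float(np.mean(np.abs(pred - target)))


def mse_loss(pred, target):
    """Mean squared error, as F.mse_loss computes with reduction='mean'."""
    return float(np.mean((pred - target) ** 2))


pred = np.zeros(4)
target = np.array([0.0, 0.0, 0.0, 4.0])  # three exact predictions, one large residual
print(l1_loss(pred, target))   # 1.0
print(mse_loss(pred, target))  # 4.0
```

The single residual of 4 contributes linearly under L1 but quadratically under MSE, so MSE is more strongly driven by the largest errors.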

Changing these hyperparameters is straightforward: create a new class that inherits from FODNetLightningModel and override the property and method below. Then use this class instead of FODNetLightningModel when training.

import torch

from fodnet.core.model import FODNetLightningModel


class MyCustomModel(FODNetLightningModel):

    @property
    def loss_func(self):
        '''Different loss function'''
        return torch.nn.functional.mse_loss

    def configure_optimizers(self):
        '''Different optimizer and learning rate'''
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-5)
        return optimizer

Prediction

from fodnet.core.model import FODNetLightningModel
from fodnet.core.prediction import FODNetPredictionProcessor

model = FODNetLightningModel.load_from_checkpoint('/path/to/my/checkpoint.ckpt')
predict = FODNetPredictionProcessor(batch_size=32, num_workers=8, accelerator='gpu')
predict.run_subject(
    model,
    '/path/to/my/brainmask.nii.gz',
    '/path/to/lowres_fod.nii.gz',
    '/path/to/dest/highres_fod.nii.gz',
    tmp_dir=None,  # Optionally specify a temporary directory to save the FOD file during processing
)

N.B. Patches are read lazily from disk, so it is recommended to place tmp_dir on an SSD.
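One way to keep the intermediate file on fast storage is to create a throwaway directory on an SSD mount and pass it as `tmp_dir`. The helper below is a hypothetical sketch built on the standard-library `tempfile` module; `run_with_scratch_dir` is not part of fodnet.

```python
import os
import tempfile


def run_with_scratch_dir(process, parent_dir):
    """Call process(tmp_dir) with a temporary directory created under parent_dir.

    The directory and everything written into it are deleted afterwards,
    even if process raises.
    """
    os.makedirs(parent_dir, exist_ok=True)
    with tempfile.TemporaryDirectory(dir=parent_dir) as tmp_dir:
        return process(tmp_dir)
```

For example, `run_with_scratch_dir(lambda tmp: predict.run_subject(model, mask_path, lowres_path, out_path, tmp_dir=tmp), '/mnt/ssd/scratch')`, where `/mnt/ssd/scratch` and the three path variables are placeholders.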

Download files

Source Distribution

fodnet-1.1.0.tar.gz (13.2 kB)

Built Distribution

fodnet-1.1.0-py3-none-any.whl (13.3 kB, Python 3)
