
FOD-Net Reimplementation

Project description

FODNet

FOD-Net reimplementation with training and inference pipelines. This module uses the FODNet model originally implemented here.

If you use this code for your research, please cite:

FOD-Net: A Deep Learning Method for Fiber Orientation Distribution Angular Super Resolution.
Rui Zeng, Jinglei Lv, He Wang, Luping Zhou, Michael Barnett, Fernando Calamante*, Chenyu Wang*. In Medical Image Analysis. (* equal contributions) [Bibtex].

Requirements

This module requires the following Python packages:

  • torch >= 2.0.0
  • lightning >= 2.0.0
  • numpy
  • einops
  • npy-patcher
  • scikit-image
  • nibabel

These will be installed automatically when this package is installed; however, it is recommended to follow the official PyTorch installation instructions beforehand, to ensure the correct hardware optimizations are enabled.

Installation

pip install fodnet
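
For example, to install a CUDA-enabled PyTorch build before installing this package (a sketch only; the index URL assumes CUDA 12.1 wheels, so adjust it to match your hardware per the official PyTorch instructions):

pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install fodnet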

Training

Follow the instructions below to train the FODNet model.

Data Preprocessing

This training pipeline requires data to be saved in .npy format, with the spherical harmonic (SH) dimension as the first dimension of each 4D array. This is because this module uses npy-patcher to extract training patches at runtime. Below is an example of how to convert a NIfTI file to .npy using nibabel.

import numpy as np
import nibabel as nib

img = nib.load('/path/to/fod.nii.gz')
data = np.asarray(img.dataobj, dtype=np.float32)  # Load FOD data into memory
data = data.transpose(3, 0, 1, 2)  # Move the SH dimension to 0
np.save('/path/to/fod.npy', data, allow_pickle=False)  # Save in npy format. Ensure this is on an SSD.

N.B. Patches are read lazily from disk, so it is highly recommended to store the training data on an SSD, as an HDD will bottleneck data loading during training.
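
The same conversion applies to every file passed to the training pipeline. Below is a minimal sketch that converts the low-resolution FOD, high-resolution FOD, and brain mask for one subject; the filenames are hypothetical, and it assumes the mask is a 3D volume that needs no transposing.

import numpy as np
import nibabel as nib

def nifti_to_npy(src, dest, move_sh_axis=True):
    '''Convert a NIfTI file to .npy, optionally moving the SH axis to dimension 0.'''
    data = np.asarray(nib.load(src).dataobj, dtype=np.float32)
    if move_sh_axis:
        data = data.transpose(3, 0, 1, 2)  # SH dimension first (4D FOD volumes only)
    np.save(dest, data, allow_pickle=False)

# Hypothetical filenames for a single subject
nifti_to_npy('/path/to/lowres_fod.nii.gz', '/path/to/lowres_fod.npy')
nifti_to_npy('/path/to/highres_fod.nii.gz', '/path/to/highres_fod.npy')
nifti_to_npy('/path/to/mask.nii.gz', '/path/to/mask.npy', move_sh_axis=False)  # 3D mask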

Training

import lightning.pytorch as pl

from fodnet.core.model import FODNetLightningModel
from fodnet.core.dataset import Subject, FODNetDataModule

# Collect dataset filepaths
subj1 = Subject('/path/to/lowres_fod1.npy', '/path/to/highres_fod1.npy', '/path/to/mask1.npy')
subj2 = Subject('/path/to/lowres_fod2.npy', '/path/to/highres_fod2.npy', '/path/to/mask2.npy')
subj3 = Subject('/path/to/lowres_fod3.npy', '/path/to/highres_fod3.npy', '/path/to/mask3.npy')

# Create DataModule instance. This is a thin wrapper around `pl.LightningDataModule`.
data_module = FODNetDataModule(
    train_subjects=(subj1, subj2),
    val_subjects=(subj3,),  # Trailing comma creates a single-element tuple
    batch_size=16,  # Batch size per device
    num_workers=8, # Number of CPU workers that load the data
)

# Load FODNet lightning model
model = FODNetLightningModel()

# Create `pl.Trainer` instance. `FODNetDataModule` is usable in DDP distributed training strategy.
trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=100)

# Start training
trainer.fit(model, data_module)
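
As noted in the comment above, FODNetDataModule is usable with the DDP strategy. Below is a minimal multi-GPU sketch with checkpointing, continuing from the training example above; the callback settings are illustrative assumptions, not part of this package.

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='/path/to/checkpoints',  # Where checkpoint files are written
    save_last=True,  # Always keep the most recent epoch
)

trainer = pl.Trainer(
    devices=4,  # Number of GPUs on this node
    accelerator='gpu',
    strategy='ddp',  # Distributed data parallel across the devices
    max_epochs=100,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, data_module)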

Customization

This implementation uses a different optimizer, loss function, and learning rate from those used in the original implementation: AdamW, L1 loss, and 0.003, respectively.

Changing these hyperparameters is straightforward: create a new class that inherits from FODNetLightningModel and override the properties/methods below, then use this class in place of FODNetLightningModel when training.

import torch

from fodnet.core.model import FODNetLightningModel


class MyCustomModel(FODNetLightningModel):

    @property
    def loss_func(self):
        '''Use a different loss function (MSE instead of L1).'''
        return torch.nn.functional.mse_loss

    def configure_optimizers(self):
        '''Use a different optimizer (SGD) and learning rate.'''
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-5)
        return optimizer
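
configure_optimizers can also return a learning-rate scheduler alongside the optimizer, using Lightning's standard dictionary format. Below is a sketch using cosine annealing, continuing from the example above; the T_max value simply assumes a 100-epoch run.

class MyScheduledModel(FODNetLightningModel):

    def configure_optimizers(self):
        '''AdamW with a cosine-annealed learning rate.'''
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}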

Prediction

from fodnet.core.model import FODNetLightningModel
from fodnet.core.prediction import FODNetPredictionProcessor

model = FODNetLightningModel.load_from_checkpoint('/path/to/my/checkpoint.ckpt')
predict = FODNetPredictionProcessor(batch_size=32, num_workers=8, accelerator='gpu')
predict.run_subject(
    model,
    '/path/to/my/brainmask.nii.gz',
    '/path/to/lowres_fod.nii.gz',
    '/path/to/dest/highres_fod.nii.gz',
    tmp_dir=None,  # Optionally specify a temporary directory to save the FOD file during processing
)

N.B. Patches are read lazily from disk, so it is recommended to ensure tmp_dir points to an SSD.
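
The same processor instance can be reused across subjects. Below is a minimal sketch for batch prediction, reusing the model and predict objects created above; the directory layout and subject IDs are hypothetical.

# Hypothetical subject IDs and directory layout
for subject_id in ('sub-01', 'sub-02', 'sub-03'):
    predict.run_subject(
        model,
        f'/data/{subject_id}/brainmask.nii.gz',
        f'/data/{subject_id}/lowres_fod.nii.gz',
        f'/data/{subject_id}/highres_fod.nii.gz',
        tmp_dir='/fast_ssd/tmp',  # Keep temporary files on an SSD
    )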

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fodnet-1.1.0.tar.gz (13.2 kB)

Uploaded Source

Built Distribution

fodnet-1.1.0-py3-none-any.whl (13.3 kB)

Uploaded Python 3

File details

Details for the file fodnet-1.1.0.tar.gz.

File metadata

  • Download URL: fodnet-1.1.0.tar.gz
  • Upload date:
  • Size: 13.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

Hashes for fodnet-1.1.0.tar.gz
  • SHA256: 370ace5f4496ef5a7b0e9d08805776c16f677e1a4ef55e5a78263d001ca34dc8
  • MD5: dfc30e3f97639276aaa011caebea8d91
  • BLAKE2b-256: 75b92b660bbddbe3554cbcfa6aa86296ce609d47da3d6957af0107523d1db392

See more details on using hashes here.

File details

Details for the file fodnet-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: fodnet-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 13.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

Hashes for fodnet-1.1.0-py3-none-any.whl
  • SHA256: 92acd2126c113c5835ade2cf5b4dc934e2e14841502a81232c8a356a78393dbb
  • MD5: 794a890155bc09be649fa280040fc1bb
  • BLAKE2b-256: 642cf95eb5f722ad434d49e9c08576754a6ce58e4246534cdc81014226f6384d

See more details on using hashes here.
