A companion library for deep learning on medical imaging
TorchMed
Read and process medical images in PyTorch.
This library is designed as a flexible tool to process various types of N-dimensional images. A set of image readers built on well-known projects (SimpleITK, NiBabel, OpenCV, Pillow) lets you load your data. Once loaded, specific sub-sampling of the original data is performed with patterns (describing what and how to extract) and samplers (which check where to extract).
With readers, samplers and patterns, you can compose datasets that are well suited for PyTorch.
Install
From pip:
pip install torchmed
Locally:
python setup.py install
Usage
Reader
>>> import torchmed
>>> image = torchmed.readers.SitkReader('prepro_im_mni_bc.nii.gz')
>>> label_map = torchmed.readers.SitkReader('prepro_seg_mni.nii.gz')
# gets image data
>>> image_array = image.to_torch()
>>> label_array = label_map.to_torch()
>>> image_array.size()
torch.Size([182, 218, 182])
>>> type(image_array)
<class 'torch.Tensor'>
>>> label_array[0,0,0]
tensor(0.)
# also available for NumPy
>>> type(image.to_numpy())
<class 'numpy.ndarray'>
Pattern
Patterns are useful to specify how the data should be extracted from an image. It is possible to apply several patterns on one or more images.
>>> import torchmed
>>> image = torchmed.readers.SitkReader('prepro_im_mni_bc.nii.gz')
>>> image_arr = image.to_torch()
>>> square_patch = torchmed.patterns.SquaredSlidingWindow([182, 7, 182], use_padding=False)
# initialize the pattern with the image properties
>>> square_patch.prepare(image_arr)
# can_apply checks if a pattern can be applied at a given position
>>> square_patch.can_apply(image_arr, [0,0,0])
False
>>> square_patch.can_apply(image_arr, [91,4,91])
True
>>> square_patch.can_apply(image_arr, [91,3,91])
True
>>> square_patch.can_apply(image_arr, [91,2,91])
False
>>> square_patch.can_apply(image_arr, [91,154,91])
True
# to extract a patch at a correct position
>>> sample = square_patch(image_arr, [91,154,91])
>>> sample.size()
torch.Size([182, 7, 182])
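The `can_apply` results above can be reproduced with a few lines of plain Python. This is only a sketch of the bounds check, not torchmed's implementation: it assumes the patch is centred on the position, starting at `position - size // 2` along each axis, and that without padding the whole patch must lie inside the image. The helper names `window_bounds` and `can_apply` are hypothetical.

```python
def window_bounds(position, patch_size):
    """Return a (start, stop) pair per axis for a patch centred on `position`.

    Assumption: the patch starts at position - size // 2 (stop is exclusive).
    """
    return [(p - s // 2, p - s // 2 + s) for p, s in zip(position, patch_size)]


def can_apply(image_shape, position, patch_size):
    """True when the whole patch lies inside the image (no padding)."""
    bounds = window_bounds(position, patch_size)
    return all(start >= 0 and stop <= dim
               for (start, stop), dim in zip(bounds, image_shape))


image_shape = [182, 218, 182]   # size of the image loaded in the Reader section
patch_size = [182, 7, 182]

for position in ([0, 0, 0], [91, 4, 91], [91, 3, 91], [91, 2, 91], [91, 154, 91]):
    print(position, can_apply(image_shape, position, patch_size))
```

Under this assumption the five positions give False, True, True, False, True, matching the session above: along the second axis a window of 7 needs at least 3 voxels before its centre.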
Sampler
A multi-process sampler that automatically searches for the coordinates where sampling (pattern extraction) is possible.
>>> from torchmed.readers import SitkReader
>>> from torchmed.samplers import MaskableSampler
>>> from torchmed.patterns import SquaredSlidingWindow
# maps a name to each image
>>> file_map = {
... 'image_ref': SitkReader('prepro_im_mni_bc.nii.gz',
... torch_type='torch.FloatTensor'),
... 'target': SitkReader('prepro_seg_mni.nii.gz',
... torch_type='torch.LongTensor')
... }
# sliding window pattern
>>> patch2d = SquaredSlidingWindow(patch_size=[182, 7, 182], use_padding=False)
# specify a pattern for each input image
>>> pattern_mapper = {'input': ('image_ref', patch2d),
... 'target': ('target', patch2d)}
# multi-process sampler with offset
>>> sampler = MaskableSampler(pattern_mapper, offset=[91, 1, 91], nb_workers=4)
>>> sampler.build(file_map)
>>> len(sampler)
212
>>> sample = sampler[0]
>>> type(sample)
<class 'tuple'>
>>> sample[0].size()
torch.Size([3])
>>> sample[1].size()
torch.Size([182, 7, 182])
>>> sample[2].size()
torch.Size([182, 7, 182])
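To see where the 212 samples come from, here is a single-process sketch of the position search in plain Python. It rests on two assumptions that are not confirmed by torchmed itself: `offset` is treated as the step between candidate positions along each axis, and a patch is valid when it is centred on the position and fits entirely inside the image. The names `patch_fits` and `valid_positions` are hypothetical.

```python
from itertools import product


def patch_fits(image_shape, position, patch_size):
    """True when a patch centred on `position` lies fully inside the image."""
    for pos, size, dim in zip(position, patch_size, image_shape):
        start = pos - size // 2          # assumed centring convention
        if start < 0 or start + size > dim:
            return False
    return True


def valid_positions(image_shape, patch_size, offset):
    """List candidate positions, stepping by `offset` along each axis."""
    axes = [range(0, dim, step) for dim, step in zip(image_shape, offset)]
    return [list(pos) for pos in product(*axes)
            if patch_fits(image_shape, pos, patch_size)]


positions = valid_positions([182, 218, 182], [182, 7, 182], [91, 1, 91])
print(len(positions))    # 212
print(positions[0])      # [91, 3, 91]
print(positions[-1])     # [91, 214, 91]
```

With a step of 91 on the first and last axes only the centre position 91 is tested and accepted, so the count is driven by the second axis: centres 3 to 214, which is 212 positions.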
Dataset
MedFile and MedFolder are iterable datasets, returning samples from the input data. Here is an example of how to build a MedFolder from a list of images. A MedFolder takes as input a list of MedFiles.
import os

from torchmed.datasets import MedFile, MedFolder
from torchmed.patterns import SquaredSlidingWindow
from torchmed.readers import SitkReader
from torchmed.samplers import MaskableSampler
# assumed import path for Pad; adjust it if your torchmed version differs
from torchmed.utils.transforms import Pad


# hypothetical wrapper class: the original snippet only shows its body
class TrainingData(object):
    def __init__(self, base_dir, nb_workers):
        self.train_dataset = MedFolder(
            self.generate_medfiles(os.path.join(base_dir, 'train'), nb_workers))

    def generate_medfiles(self, dir, nb_workers):
        # database composed of the directory names listed in allowed_data.txt
        with open(os.path.join(dir, 'allowed_data.txt'), 'r') as database:
            patient_list = [line.rstrip('\n') for line in database]

        medfiles = []
        # builds a list of MedFiles, one for each folder
        for patient in patient_list:
            if patient:
                patient_dir = os.path.join(dir, patient)
                patient_data = self.build_patient_data_map(patient_dir)
                patient_file = MedFile(patient_data, self.build_sampler(nb_workers))
                medfiles.append(patient_file)

        return medfiles

    def build_patient_data_map(self, dir):
        # pads each dimension of the image on both sides
        pad_reflect = Pad(((1, 1), (3, 3), (1, 1)), 'reflect')
        file_map = {
            'image_ref': SitkReader(
                os.path.join(dir, 'prepro_im_mni_bc.nii.gz'),
                torch_type='torch.FloatTensor', transform=pad_reflect),
            'target': SitkReader(
                os.path.join(dir, 'prepro_seg_mni.nii.gz'),
                torch_type='torch.LongTensor', transform=pad_reflect)
        }

        return file_map

    def build_sampler(self, nb_workers):
        # sliding window of size [184, 7, 184] without padding
        patch2d = SquaredSlidingWindow(patch_size=[184, 7, 184], use_padding=False)
        # pattern map links each image id to a pattern
        pattern_mapper = {'input': ('image_ref', patch2d),
                          'target': ('target', patch2d)}

        # add a fixed offset to make patch sampling faster (doesn't look for all positions)
        return MaskableSampler(pattern_mapper, offset=[92, 1, 92],
                               nb_workers=nb_workers)
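The patch size and offset in this example follow from the padding. The arithmetic below is a sketch under two assumptions: the images have the size shown in the Reader section ([182, 218, 182]), and the `Pad` argument lists (before, after) widths per axis, as `numpy.pad` does. The helper names are hypothetical.

```python
def padded_shape(shape, pad_width):
    """Shape after adding (before, after) voxels along each axis."""
    return [dim + before + after
            for dim, (before, after) in zip(shape, pad_width)]


def count_window_centres(length, window):
    """Number of centres along one axis where a full window fits (step 1)."""
    return max(0, length - window + 1)


original = [182, 218, 182]
padded = padded_shape(original, ((1, 1), (3, 3), (1, 1)))

print(padded)                                 # [184, 224, 184]
print(count_window_centres(original[1], 7))   # 212
print(count_window_centres(padded[1], 7))     # 218
```

Padding the second axis by 3 on each side gives every one of the 218 original slices a full 7-slice window, and the padded first and last axes (184) match the patch size, so an offset of 92 lands on their single valid centre.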
Examples
See the datasets folder of the examples for a more practical use case.
Credits
Evaluation metrics are mostly based on MedPy.