
A companion library for deep learning on medical imaging

Project description


Read and process medical images in PyTorch.


This library is designed as a flexible tool to process various types of N-dimensional images. A set of image readers built on well-known projects (SimpleITK, NiBabel, OpenCV, Pillow) lets you load your data. Once loaded, specific sub-sampling of the original data is performed with patterns (which describe what to extract and how) and samplers (which check where extraction is possible).

With readers, samplers and patterns, you can compose datasets that are perfectly suited for PyTorch.


From pip:

pip install torchmed

Locally:

python setup.py install



>>> import torchmed

>>> image = torchmed.readers.SitkReader('prepro_im_mni_bc.nii.gz')
>>> label_map = torchmed.readers.SitkReader('prepro_seg_mni.nii.gz')
# gets image data
>>> image_array = image.to_torch()
>>> label_array = label_map.to_torch()

>>> image_array.size()
torch.Size([182, 218, 182])
>>> type(image_array)
<class 'torch.Tensor'>
>>> label_array[0,0,0]
# also available for Numpy
>>> type(image.to_numpy())
<class 'numpy.ndarray'>


Patterns are useful to specify how the data should be extracted from an image. It is possible to apply several patterns on one or more images.

>>> import torchmed

>>> image = torchmed.readers.SitkReader('prepro_im_mni_bc.nii.gz')
# gets the image data as a tensor
>>> image_arr = image.to_torch()
>>> square_patch = torchmed.patterns.SquaredSlidingWindow([182, 7, 182], use_padding=False)
# initialize the pattern with the image properties
>>> square_patch.prepare(image_arr)

# can_apply checks if a pattern can be applied at a given position
>>> square_patch.can_apply(image_arr, [0,0,0])
>>> square_patch.can_apply(image_arr, [91,4,91])
>>> square_patch.can_apply(image_arr, [91,3,91])
>>> square_patch.can_apply(image_arr, [91,2,91])
>>> square_patch.can_apply(image_arr, [91,154,91])

# to extract a patch at a correct position
>>> sample = square_patch(image_arr, [91,154,91])
>>> sample.size()
torch.Size([182, 7, 182])
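
To make the idea behind can_apply concrete, here is a minimal pure-Python sketch of a bounds check for a patch centered on a position, without padding. It is not torchmed's implementation: the function below is hypothetical, and the centering convention (size // 2 voxels before the center, the rest after) is an assumption chosen for illustration.

```python
def can_apply(image_shape, patch_size, position):
    """Return True if a patch centered on `position` lies fully inside the image.

    Hypothetical helper, not part of torchmed. Assumed convention: along each
    dimension the patch covers [center - size // 2, center + size - size // 2 - 1].
    """
    for dim, size, center in zip(image_shape, patch_size, position):
        before = size // 2          # voxels taken before the center
        after = size - before - 1   # voxels taken after the center
        if center - before < 0 or center + after >= dim:
            return False
    return True


image_shape = (182, 218, 182)
patch_size = (182, 7, 182)

print(can_apply(image_shape, patch_size, (0, 0, 0)))      # False
print(can_apply(image_shape, patch_size, (91, 3, 91)))    # True
print(can_apply(image_shape, patch_size, (91, 2, 91)))    # False
print(can_apply(image_shape, patch_size, (91, 154, 91)))  # True
```

Under this convention, a patch that spans a whole dimension (182 out of 182) fits at a single center only, which is why the first and last coordinates are fixed at 91 here.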


Samplers are multi-processed and automatically search for the coordinates where sampling (pattern extraction) is possible.

>>> from torchmed.readers import SitkReader
>>> from torchmed.samplers import MaskableSampler
>>> from torchmed.patterns import SquaredSlidingWindow

# maps a name to each image
>>> file_map = {
...         'image_ref': SitkReader('prepro_im_mni_bc.nii.gz',
...             torch_type='torch.FloatTensor'),
...         'target': SitkReader('prepro_seg_mni.nii.gz',
...             torch_type='torch.LongTensor')
...     }

# sliding window pattern
>>> patch2d = SquaredSlidingWindow(patch_size=[182, 7, 182], use_padding=False)
# specify a pattern for each input image
>>> pattern_mapper = {'input': ('image_ref', patch2d),
...                   'target': ('target', patch2d)}
# multi-processed sampler with offset
>>> sampler = MaskableSampler(pattern_mapper, offset=[91, 1, 91], nb_workers=4)
>>> len(sampler)
>>> sample = sampler[0]
>>> type(sample)
<class 'tuple'>
>>> sample[0].size()
>>> sample[1].size()
torch.Size([182, 7, 182])
>>> sample[2].size()
torch.Size([182, 7, 182])
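
The coordinate search described above can be sketched in plain Python. This is a single-process illustration, not MaskableSampler itself: the helper names are hypothetical, and it assumes that the offset acts as a per-dimension step between candidate positions and that patches are centered on the position.

```python
from itertools import product


def fits(image_shape, patch_size, position):
    """True if a patch centered on `position` stays inside the image (assumed convention)."""
    for dim, size, center in zip(image_shape, patch_size, position):
        before = size // 2
        after = size - before - 1
        if center - before < 0 or center + after >= dim:
            return False
    return True


def candidate_positions(image_shape, patch_size, step):
    """List every position on the grid defined by `step` where the patch fits.

    Hypothetical helper: `step` plays the role assumed here for the sampler offset.
    """
    axes = [range(0, dim, s) for dim, s in zip(image_shape, step)]
    return [pos for pos in product(*axes) if fits(image_shape, patch_size, pos)]


positions = candidate_positions((182, 218, 182), (182, 7, 182), (91, 1, 91))
print(len(positions))   # 212
print(positions[0])     # (91, 3, 91)
print(positions[-1])    # (91, 214, 91)
```

With a step of 91 along the first and last dimensions, only 2 x 218 x 2 candidates are tested instead of every voxel, which is the reason a fixed offset speeds up the search.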


MedFile and MedFolder are iterable datasets, returning samples from the input data. Here is an example of how to build a MedFolder from a list of images. A MedFolder takes as input a list of MedFiles.

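import os
from torchmed.datasets import MedFile, MedFolder
from torchmed.patterns import SquaredSlidingWindow
from torchmed.readers import SitkReader
from torchmed.samplers import MaskableSampler
from torchmed.utils.transforms import Pad  # assumed import path for Pad

# the snippets below are excerpts from a user-defined class;
# base_dir and nb_workers are supplied by the caller
self.train_dataset = MedFolder(
        self.generate_medfiles(os.path.join(base_dir, 'train'), nb_workers))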

def generate_medfiles(self, dir, nb_workers):
      # the database is the list of directory names found in allowed_data.txt
      with open(os.path.join(dir, 'allowed_data.txt'), 'r') as database:
          patient_list = [line.rstrip('\n') for line in database]
      medfiles = []

      # builds a list of MedFiles, one for each folder
      for patient in patient_list:
          if patient:
              patient_dir = os.path.join(dir, patient)
              patient_data = self.build_patient_data_map(patient_dir)
              patient_file = MedFile(patient_data, self.build_sampler(nb_workers))
              medfiles.append(patient_file)

      return medfiles

def build_patient_data_map(self, dir):
      # pads each dimension of the image on both sides.
      pad_reflect = Pad(((1, 1), (3, 3), (1, 1)), 'reflect')
      file_map = {
          'image_ref': SitkReader(
              os.path.join(dir, 'prepro_im_mni_bc.nii.gz'),
              torch_type='torch.FloatTensor', transform=pad_reflect),
          'target': SitkReader(
              os.path.join(dir, 'prepro_seg_mni.nii.gz'),
              torch_type='torch.LongTensor', transform=pad_reflect)
      }

      return file_map

def build_sampler(self, nb_workers):
    # sliding window of size [184, 7, 184] without padding
    patch2d = SquaredSlidingWindow(patch_size=[184, 7, 184], use_padding=False)
    # the pattern map links each input name to an image id and a pattern
    pattern_mapper = {'input': ('image_ref', patch2d),
                      'target': ('target', patch2d)}

    # add a fixed offset to make patch sampling faster (doesn't look for all positions)
    return MaskableSampler(pattern_mapper, offset=[92, 1, 92],
                           nb_workers=nb_workers)


See the datasets folder of the examples for a more practical use case.
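
As a rough illustration of how a folder-level dataset can expose a list of per-file datasets as one indexable sequence, the sketch below maps a global index to a file and a local index using cumulative lengths. The class FolderOfFiles is hypothetical and is not torchmed's MedFolder; plain lists stand in for MedFiles.

```python
from bisect import bisect_right
from itertools import accumulate


class FolderOfFiles:
    """Hypothetical stand-in for a folder-level dataset (not torchmed's MedFolder)."""

    def __init__(self, file_datasets):
        self.file_datasets = list(file_datasets)
        # cumulative[i] = number of samples in files 0..i
        self.cumulative = list(accumulate(len(d) for d in self.file_datasets))

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        # first file whose cumulative count exceeds the global index
        file_idx = bisect_right(self.cumulative, index)
        start = self.cumulative[file_idx - 1] if file_idx else 0
        return self.file_datasets[file_idx][index - start]


# three "files" with 2, 1 and 3 samples
folder = FolderOfFiles([["a0", "a1"], ["b0"], ["c0", "c1", "c2"]])
print(len(folder))  # 6
print(folder[2])    # b0
print(folder[5])    # c2
```

Any object that implements __len__ and __getitem__ in this way can be consumed by a PyTorch DataLoader, which is what makes this kind of composition convenient.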


Evaluation metrics are mostly based on MedPy.


Files for torchmed, version 0.0.1a0:

Filename                           Size     File type  Python version
torchmed-0.0.1a0.tar.gz            24.4 kB  Source     None
torchmed-0.0.1a0-py3-none-any.whl  34.1 kB  Wheel      py3
