A companion library for deep learning on medical imaging

TorchMed

Read and process medical images in PyTorch.


This library is designed as a flexible tool for processing various types of N-dimensional images. A set of image readers built on well-known projects (SimpleITK, NiBabel, OpenCV, Pillow) lets you load your data. Once loaded, the original data is sub-sampled using patterns (which describe what to extract and how) and samplers (which determine where to extract).

With readers, samplers and patterns, you can compose datasets that are well suited to PyTorch.

Install

From pip:

pip install torchmed

Locally:

python setup.py install

Usage

Reader

>>> import torchmed

>>> image = torchmed.readers.SitkReader('prepro_im_mni_bc.nii.gz')
>>> label_map = torchmed.readers.SitkReader('prepro_seg_mni.nii.gz')
# gets image data
>>> image_array = image.to_torch()
>>> label_array = label_map.to_torch()

>>> image_array.size()
torch.Size([182, 218, 182])
>>> type(image_array)
<class 'torch.Tensor'>
>>> label_array[0,0,0]
tensor(0.)
# also available for Numpy
>>> type(image.to_numpy())
<class 'numpy.ndarray'>

Pattern

Patterns are useful to specify how the data should be extracted from an image. It is possible to apply several patterns on one or more images.

>>> import torchmed

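>>> image = torchmed.readers.SitkReader('prepro_im_mni_bc.nii.gz')
# the pattern operates on the image data, so convert it to a tensor first
>>> image_arr = image.to_torch()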
>>> square_patch = torchmed.patterns.SquaredSlidingWindow([182, 7, 182], use_padding=False)
# initialize the pattern with the image properties
>>> square_patch.prepare(image_arr)

# can_apply checks if a pattern can be applied at a given position
>>> square_patch.can_apply(image_arr, [0,0,0])
False
>>> square_patch.can_apply(image_arr, [91,4,91])
True
>>> square_patch.can_apply(image_arr, [91,3,91])
True
>>> square_patch.can_apply(image_arr, [91,2,91])
False
>>> square_patch.can_apply(image_arr, [91,154,91])
True

# to extract a patch at a correct position
>>> sample = square_patch(image_arr, [91,154,91])
>>> sample.size()
torch.Size([182, 7, 182])
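
The can_apply results above are consistent with a window centered on the requested position that must fit entirely inside the image when padding is disabled. The following is a minimal sketch of that bounds check in plain Python. It is not TorchMed's implementation: the helper names `window_bounds` and `fits_inside` are hypothetical, and the centering convention (start = center - size // 2) is an assumption chosen because it reproduces the outputs shown above.

```python
def window_bounds(center, patch_size):
    """Return a (start, stop) pair per axis for a window centered on `center`.

    Assumed convention: the window starts at center - size // 2 and
    spans `size` voxels, so stop is exclusive.
    """
    return [(c - s // 2, c - s // 2 + s) for c, s in zip(center, patch_size)]


def fits_inside(shape, center, patch_size):
    """True if the window lies entirely within an image of the given shape."""
    bounds = window_bounds(center, patch_size)
    return all(0 <= start and stop <= dim
               for (start, stop), dim in zip(bounds, shape))


image_shape = (182, 218, 182)   # size reported by image_array.size()
patch_size = (182, 7, 182)

for position in [(0, 0, 0), (91, 4, 91), (91, 3, 91), (91, 2, 91), (91, 154, 91)]:
    print(position, fits_inside(image_shape, position, patch_size))
```

Under this convention, a patch of width 7 needs three voxels before its center along the second axis, which is why position 3 is accepted and position 2 is rejected.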

Sampler

A sampler is multi-processed and automatically searches for the coordinates where sampling (pattern extraction) is possible.

>>> from torchmed.readers import SitkReader
>>> from torchmed.samplers import MaskableSampler
>>> from torchmed.patterns import SquaredSlidingWindow

# maps a name to each image
>>> file_map = {
...         'image_ref': SitkReader('prepro_im_mni_bc.nii.gz',
...             torch_type='torch.FloatTensor'),
...         'target': SitkReader('prepro_seg_mni.nii.gz',
...             torch_type='torch.LongTensor')
...     }

# sliding window pattern
>>> patch2d = SquaredSlidingWindow(patch_size=[182, 7, 182], use_padding=False)
# specify a pattern for each input image
>>> pattern_mapper = {'input': ('image_ref', patch2d),
...                   'target': ('target', patch2d)}
# multi-processed sampler with offset
>>> sampler = MaskableSampler(pattern_mapper, offset=[91, 1, 91], nb_workers=4)
>>> sampler.build(file_map)
>>> len(sampler)
212
>>> sample = sampler[0]
>>> type(sample)
<class 'tuple'>
>>> sample[0].size()
torch.Size([3])
>>> sample[1].size()
torch.Size([182, 7, 182])
>>> sample[2].size()
torch.Size([182, 7, 182])
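
The sampler length of 212 can be reproduced by enumerating candidate positions by hand. The sketch below is a single-process illustration in plain Python, not TorchMed's code. It assumes the offset acts as a step between candidate positions along each axis, starting at index 0, and that a position is kept only when a centered window fits inside the image. The function name `valid_positions` is hypothetical.

```python
from itertools import product


def valid_positions(shape, patch_size, offset):
    """List every candidate center where a centered patch fits in the image.

    Assumption: candidates are visited every `offset` voxels per axis,
    starting from index 0.
    """
    axes = [range(0, dim, step) for dim, step in zip(shape, offset)]
    positions = []
    for center in product(*axes):
        fits = all(0 <= c - s // 2 and c - s // 2 + s <= dim
                   for c, s, dim in zip(center, patch_size, shape))
        if fits:
            positions.append(center)
    return positions


positions = valid_positions(shape=(182, 218, 182),
                            patch_size=(182, 7, 182),
                            offset=(91, 1, 91))
print(len(positions))   # 212, matching len(sampler) above
print(positions[0])     # (91, 3, 91)
```

With a patch that spans the full first and third axes, only the center index 91 fits on those axes, so the count comes entirely from the 212 valid centers (3 to 214) along the second axis. The three-element tensor returned as sample[0] has the right size to hold such a position.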

Dataset

MedFile and MedFolder are iterable datasets, returning samples from the input data. Here is an example of how to build a MedFolder from a list of images. A MedFolder takes as input a list of MedFiles.

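import os
from torchmed.datasets import MedFile, MedFolder
from torchmed.patterns import SquaredSlidingWindow
from torchmed.readers import SitkReader
from torchmed.samplers import MaskableSampler
# NOTE: the import path for Pad is an assumption; adjust it to your torchmed version
from torchmed.utils.transforms import Pad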

# excerpt from a class: base_dir and nb_workers are provided by the caller
self.train_dataset = MedFolder(
    self.generate_medfiles(os.path.join(base_dir, 'train'), nb_workers))

def generate_medfiles(self, dir, nb_workers):
    # allowed_data.txt lists the patient directory names to include, one per line
    with open(os.path.join(dir, 'allowed_data.txt'), 'r') as database:
        patient_list = [line.rstrip('\n') for line in database]
    medfiles = []

    # builds a list of MedFiles, one for each patient folder
    for patient in patient_list:
        if patient:  # skip empty lines
            patient_dir = os.path.join(dir, patient)
            patient_data = self.build_patient_data_map(patient_dir)
            patient_file = MedFile(patient_data, self.build_sampler(nb_workers))
            medfiles.append(patient_file)

    return medfiles

def build_patient_data_map(self, dir):
    # pads each dimension of the image on both sides.
    pad_reflect = Pad(((1, 1), (3, 3), (1, 1)), 'reflect')
    file_map = {
        'image_ref': SitkReader(
            os.path.join(dir, 'prepro_im_mni_bc.nii.gz'),
            torch_type='torch.FloatTensor', transform=pad_reflect),
        'target': SitkReader(
            os.path.join(dir, 'prepro_seg_mni.nii.gz'),
            torch_type='torch.LongTensor', transform=pad_reflect)
    }

    return file_map

def build_sampler(self, nb_workers):
    # sliding window of size [184, 7, 184] without padding
    patch2d = SquaredSlidingWindow(patch_size=[184, 7, 184], use_padding=False)
    # the pattern mapper links each output name to an image id and its pattern
    pattern_mapper = {'input': ('image_ref', patch2d),
                      'target': ('target', patch2d)}

    # add a fixed offset to make patch sampling faster (doesn't look for all positions)
    return MaskableSampler(pattern_mapper, offset=[92, 1, 92],
                           nb_workers=nb_workers)
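
The padding and patch sizes in the example above fit together arithmetically. The sketch below works through the numbers in plain Python, assuming the input volumes have the same 182 x 218 x 182 size as in the Reader section (the dataset example does not state this). The helper names `padded_shape` and `count_valid_centers` are hypothetical and are not part of TorchMed.

```python
def padded_shape(shape, pad_width):
    """Shape of an image after adding (before, after) padding on each axis."""
    return tuple(dim + before + after
                 for dim, (before, after) in zip(shape, pad_width))


def count_valid_centers(dim, size):
    """Number of positions along one axis where a window of `size` fits."""
    return max(0, dim - size + 1)


original = (182, 218, 182)               # assumed input size
pad_width = ((1, 1), (3, 3), (1, 1))     # same values as pad_reflect above
padded = padded_shape(original, pad_width)

print(padded)                               # (184, 224, 184)
print(count_valid_centers(padded[1], 7))    # 218
print(count_valid_centers(original[1], 7))  # 212
```

Padding three voxels on each side of the second axis gives a width-7 window room to be centered on every one of the 218 original slices, instead of the 212 reachable without padding. The padded first and third axes (184) match the patch size exactly, which is consistent with the offset of 92 used for those axes.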

Examples

See the datasets folder of the examples for a more practical use case.

Credits

Evaluation metrics are mostly based on MedPy.
