
Convert your dataset class into one that can be used for tiled prediction, and stitch the predicted tiles back into a full prediction.

Project description

A lean wrapper around your dataset class to enable tiled prediction.


Objective

This package subclasses the dataset class you use to train your network. With PredTiler, you can use your dataset class as is, and PredTiler takes care of the tiling logic for you: it automatically generates patches that can be tiled with an overlap of (patch_size - tile_size)//2. It also provides a function that stitches the predicted tiles back together into the final prediction.
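The geometry behind the overlap formula is not spelled out here. One common layout, assumed for this sketch and not taken from predtiler's source, places tiles on a regular grid and centres each patch on its tile. Under that reading, (patch_size - tile_size)//2 is the context margin on each side of a tile, and neighbouring patches overlap by patch_size - tile_size pixels. The hypothetical helper `tile_windows` computes this layout along a single axis:

```python
from typing import List, Tuple


def tile_windows(length: int, tile_size: int, patch_size: int) -> List[Tuple[int, int]]:
    """Return (tile_start, patch_start) pairs along one axis.

    Assumed layout (illustrative only): tiles start at multiples of tile_size,
    and each patch is centred on its tile, extending
    margin = (patch_size - tile_size) // 2 pixels beyond it on each side.
    At the borders patch_start can be negative (or the patch can run past
    `length`); a real implementation has to pad or shift those patches.
    """
    margin = (patch_size - tile_size) // 2
    windows = []
    for tile_start in range(0, length, tile_size):
        patch_start = tile_start - margin
        windows.append((tile_start, patch_start))
    return windows


if __name__ == "__main__":
    # 512-pixel axis, 128-pixel tiles, 256-pixel patches -> 64-pixel margin
    for tile_start, patch_start in tile_windows(512, 128, 256):
        print(f"tile [{tile_start}, {tile_start + 128})  patch [{patch_start}, {patch_start + 256})")
```

With patch_size = 256 and tile_size = 128, each tile gets 64 pixels of context on both sides, and consecutive patches share 128 pixels.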

If you run into problems, feel free to open an issue and I will be happy to help! In the future, I plan to add detailed instructions for:

  1. multi-channel data
  2. 3D data
  3. data given as a list of NumPy arrays, each possibly having a different shape.

Installation

pip install predtiler

Usage

To work with PredTiler, your dataset class needs only one thing: a patch_location(self, index) method that returns the location of the patch at the given index. Your dataset class must rely solely on the location returned by this method to decide which patch to return. PredTiler overrides this method so that it returns the locations of the patches needed for tiled prediction.
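To see why overriding a single method is enough, here is a minimal sketch of the idea. Both `ToyDataset` and `with_fixed_locations` are hypothetical names for illustration; this is not how predtiler is implemented internally. Because `__getitem__` gets its crop position only from `patch_location`, a subclass that returns precomputed locations fully controls which crops come out:

```python
from typing import List, Tuple

import numpy as np


class ToyDataset:
    """Hypothetical dataset that follows the patch_location contract."""

    def __init__(self, data: np.ndarray, patch_size: int = 4) -> None:
        self.data = data  # shape: (N, H, W)
        self.patch_size = patch_size

    def patch_location(self, index: int) -> Tuple[int, int, int]:
        # Training-time behaviour: a random (but index-seeded) location.
        rng = np.random.default_rng(index)
        n = int(rng.integers(0, self.data.shape[0]))
        h = int(rng.integers(0, self.data.shape[1] - self.patch_size + 1))
        w = int(rng.integers(0, self.data.shape[2] - self.patch_size + 1))
        return n, h, w

    def __getitem__(self, index: int) -> np.ndarray:
        n, h, w = self.patch_location(index)  # the only source of the crop position
        p = self.patch_size
        return self.data[n, h:h + p, w:w + p]


def with_fixed_locations(dataset_cls, locations: List[Tuple[int, int, int]]):
    """Hypothetical helper: subclass dataset_cls so that patch_location
    reads from a precomputed list instead of sampling at random."""

    class Fixed(dataset_cls):
        def patch_location(self, index: int) -> Tuple[int, int, int]:
            return locations[index]

        def __len__(self) -> int:
            return len(locations)

    return Fixed
```

For example, `with_fixed_locations(ToyDataset, [(0, 0, 0), (1, 4, 4)])(data, patch_size=4)[1]` returns `data[1, 4:8, 4:8]`, whatever the original random sampling would have produced.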

Note that your dataset class can be arbitrarily complex (augmentations, returning multiple patches, working with 3D data, etc.). The only requirement is that it uses the crop at the location returned by the patch_location method. Below is an example of a simple dataset class that can be used with PredTiler.

from typing import Tuple

import numpy as np

class YourDataset:
    def __init__(self, data_path, patch_size=64) -> None:
        self.patch_size = patch_size
        # load_data is your own loading function
        self.data = load_data(data_path) # shape: (N, H, W)

    def patch_location(self, index: int) -> Tuple[int, int, int]:
        # ignores the index and returns a random location;
        # PredTiler overrides this method for tiled prediction
        n_idx = np.random.randint(0, len(self.data))
        # +1 because randint excludes its upper bound
        h = np.random.randint(0, self.data.shape[1] - self.patch_size + 1)
        w = np.random.randint(0, self.data.shape[2] - self.patch_size + 1)
        return (n_idx, h, w)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        n_idx, h, w = self.patch_location(index)
        # return the (patch_size, patch_size) crop at that location
        return self.data[n_idx, h:h+self.patch_size, w:w+self.patch_size]
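The note above mentions datasets that return multiple patches. The sketch below uses a hypothetical `PairedDataset` that works on in-memory arrays instead of `load_data`, and returns an (input, target) pair. The important detail is that `patch_location` is called once per item and the same window is reused for every array. If PredTiler overrides `patch_location`, both outputs then follow the same tile.

```python
from typing import Tuple

import numpy as np


class PairedDataset:
    """Hypothetical dataset returning an (input, target) pair from one location."""

    def __init__(self, inputs: np.ndarray, targets: np.ndarray, patch_size: int = 64) -> None:
        assert inputs.shape == targets.shape  # both shaped (N, H, W)
        self.inputs = inputs
        self.targets = targets
        self.patch_size = patch_size

    def patch_location(self, index: int) -> Tuple[int, int, int]:
        # Random location for training; randint's upper bound is exclusive.
        n_idx = np.random.randint(0, len(self.inputs))
        h = np.random.randint(0, self.inputs.shape[1] - self.patch_size + 1)
        w = np.random.randint(0, self.inputs.shape[2] - self.patch_size + 1)
        return (n_idx, h, w)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        # Ask for the location once and reuse it, so both crops stay aligned.
        n_idx, h, w = self.patch_location(index)
        window = (n_idx, slice(h, h + self.patch_size), slice(w, w + self.patch_size))
        return self.inputs[window], self.targets[window]
```

Calling `patch_location` separately for each array would give misaligned crops during training and would also break the one-call-per-patch assumption a wrapper is likely to rely on.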

Getting overlapping patches needed for tiled prediction

To use PredTiler, we create a new class that wraps your dataset class. For this, we also need a tile manager, which keeps track of the tiles.

from predtiler.dataset import get_tiling_dataset, get_tile_manager

patch_size = 256
tile_size = 128
data_shape = (10, 2048, 2048) # shape of the data you are working with: (N, H, W)
manager = get_tile_manager(data_shape=data_shape,
                           tile_shape=(1, tile_size, tile_size),
                           patch_shape=(1, patch_size, patch_size))

dset_class = get_tiling_dataset(YourDataset, manager)
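As a rough sanity check, you can count the tiles that cover the data, assuming the tile manager covers the array with a regular grid of `tile_shape` tiles. That is an assumption; check predtiler's tile manager for how it handles shapes that are not multiples of the tile size. The hypothetical helper `num_tiles` does the counting. For `data_shape = (10, 2048, 2048)` and `tile_shape = (1, 128, 128)`, it gives 10 * 16 * 16 = 2560 tiles, each of which needs one patch.

```python
import math
from typing import Sequence


def num_tiles(data_shape: Sequence[int], tile_shape: Sequence[int]) -> int:
    """Number of tiles in a regular grid covering data_shape.

    Assumes tiles start at multiples of tile_shape along every axis; a
    partially covered last tile along an axis still counts as one tile.
    """
    return math.prod(math.ceil(d / t) for d, t in zip(data_shape, tile_shape))


if __name__ == "__main__":
    print(num_tiles((10, 2048, 2048), (1, 128, 128)))
```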

At this point, you can use dset_class just as you would use YourDataset.

data_path = ... # path to your data
dset = dset_class(data_path, patch_size=patch_size)

Stitching the predictions

Because PredTiler generates the patches so that they can be tiled with an overlap of (patch_size - tile_size)//2, you can run your model on every patch in order and then stitch the predictions back together, without writing any tiling logic yourself.

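import numpy as np
import torch
# import path assumed; adjust if stitch_predictions lives elsewhere in your version
from predtiler.tile_stitcher import stitch_predictions

model = ... # your model
model.eval()
predictions = []
with torch.no_grad():  # inference only, so no gradients are tracked
    for i in range(len(dset)):
        inp = dset[i]
        inp = torch.Tensor(inp)[None, None]  # add batch and channel dimensions
        pred = model(inp)
        predictions.append(pred[0].cpu().numpy())

predictions = np.stack(predictions) # shape: (number_of_patches, C, patch_size, patch_size)
stitched_pred = stitch_predictions(predictions, dset.tile_manager)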
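stitch_predictions does the bookkeeping for you. To illustrate the idea only, here is a simplified sketch that is not predtiler's implementation. It assumes 2D single-channel images whose sides are multiples of `tile`, an even `patch - tile` difference, and reflect padding at the image border. The hypothetical `extract_patches` cuts the image into patches centred on a regular grid of tiles, and the hypothetical `stitch` keeps only the central `tile x tile` region of each prediction. If the "model" is the identity, stitching recovers the input exactly.

```python
from typing import List, Tuple

import numpy as np


def extract_patches(image: np.ndarray, tile: int, patch: int) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """Cut a 2D image into patches centred on a regular grid of tiles."""
    H, W = image.shape
    assert (patch - tile) % 2 == 0 and H % tile == 0 and W % tile == 0
    margin = (patch - tile) // 2
    # Reflect-pad so that border tiles also get full context.
    padded = np.pad(image, margin, mode="reflect")
    patches, origins = [], []
    for h in range(0, H, tile):
        for w in range(0, W, tile):
            # The tile at (h, w) sits at (h + margin, w + margin) in padded
            # coordinates, so its centred patch starts at (h, w) there.
            patches.append(padded[h:h + patch, w:w + patch])
            origins.append((h, w))
    return np.stack(patches), origins


def stitch(preds: np.ndarray, origins: List[Tuple[int, int]],
           shape: Tuple[int, int], tile: int, patch: int) -> np.ndarray:
    """Place the central tile x tile crop of each prediction at its tile origin."""
    margin = (patch - tile) // 2
    out = np.zeros(shape, dtype=preds.dtype)
    for pred, (h, w) in zip(preds, origins):
        out[h:h + tile, w:w + tile] = pred[margin:margin + tile, margin:margin + tile]
    return out


if __name__ == "__main__":
    img = np.random.rand(64, 64)
    patches, origins = extract_patches(img, tile=16, patch=32)
    restored = stitch(patches, origins, img.shape, tile=16, patch=32)  # identity "model"
    print(patches.shape, np.allclose(restored, img))
```

Discarding the outer margin of each prediction is what makes overlapping patches useful: pixels near a patch border see less context and tend to be predicted less reliably, so only the well-supported centre is kept.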


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

predtiler-0.0.1.tar.gz (11.0 kB)

Uploaded Source

Built Distribution

predtiler-0.0.1-py3-none-any.whl (7.2 kB)

Uploaded Python 3

File details

Details for the file predtiler-0.0.1.tar.gz.

File metadata

  • Download URL: predtiler-0.0.1.tar.gz
  • Upload date:
  • Size: 11.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for predtiler-0.0.1.tar.gz
Algorithm Hash digest
SHA256 ed54d54f1e7a8b4395368b669681fc4db80ef814ec3186c83714e7af2cfec440
MD5 7fbdc9d8637173516b1325067116818a
BLAKE2b-256 011bc9adde0776ca2efafc169d6f00d2f2c0a63940ade3872684065b79c5ae04


File details

Details for the file predtiler-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: predtiler-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for predtiler-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 233f21db669e6465ee86715e07ee6aead1f7fac970c85fe878f8844d0e3b3073
MD5 943945f24a684e3696a46dd60614bf4c
BLAKE2b-256 f816bc78b8e77a04d8b3486063646b0549d92687fcb28a8462f7b9fc0f37f8cb
