Subclassing your dataset class into a new class that can be used for tiled prediction to obtain the stitched prediction.

Project description

A lean wrapper around your dataset class to enable tiled prediction.

Objective

This package subclasses the dataset class you use to train your network. With PredTiler, you can use your dataset class as is, and PredTiler takes care of the tiling logic for you: it automatically generates patches that can be tiled with an overlap of (patch_size - tile_size)//2. It also provides a function to stitch the tiles back together into the final prediction.
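To make the formula concrete: reading (patch_size - tile_size)//2 as the context margin on each side of the tile that is kept, patch_size = 256 and tile_size = 128 give a 64-pixel margin, so neighbouring patches share 128 pixels. The sketch below computes tile and patch start positions along one axis, assuming tiles sit on a regular grid with stride tile_size. The helper `tile_and_patch_starts` is hypothetical and is not PredTiler's internal code.

```python
from typing import List, Tuple


def tile_and_patch_starts(length: int, tile_size: int, patch_size: int) -> List[Tuple[int, int]]:
    """Return (tile_start, patch_start) pairs along one axis.

    Hypothetical helper for illustration only; not part of predtiler.
    """
    margin = (patch_size - tile_size) // 2  # context kept on each side of a tile
    pairs = []
    for tile_start in range(0, length, tile_size):
        # the patch extends `margin` pixels beyond the tile on both sides;
        # at the image border this start is negative, so some padding is assumed
        pairs.append((tile_start, tile_start - margin))
    return pairs


print(tile_and_patch_starts(length=512, tile_size=128, patch_size=256))
# [(0, -64), (128, 64), (256, 192), (384, 320)]
```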

If you run into problems, feel free to raise an issue and I will be happy to help! In the future, I plan to add detailed instructions for:

  1. multi-channel data
  2. 3D data
  3. Data being a list of numpy arrays, each possibly having a different shape.

Installation

pip install predtiler

Usage

To work with PredTiler, the only requirement is that your dataset class must have a patch_location(self, index) method that returns the location of the patch at the given index. Your dataset class should only use the location information returned by this method to return the patch. PredTiler will override this method to return the location of the patches needed for tiled prediction.
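The reason this single method is enough can be illustrated with plain Python subclassing. This is a minimal sketch of the concept only, assuming the grid locations are already known (here non-overlapping 2x2 crops, for brevity); `make_tiled_class` and `ToyDataset` are hypothetical names, and PredTiler's actual `get_tiling_dataset` may be implemented differently.

```python
from typing import List, Tuple

import numpy as np


def make_tiled_class(base_cls, locations: List[Tuple[int, int, int]]):
    """Hypothetical illustration (not predtiler's code): return a subclass of
    base_cls whose patch_location yields precomputed grid locations."""

    class TiledDataset(base_cls):
        def patch_location(self, index: int) -> Tuple[int, int, int]:
            # deterministic: the index-th location on the prediction grid
            return locations[index]

        def __len__(self) -> int:
            # one item per grid location rather than one per image
            return len(locations)

    return TiledDataset


class ToyDataset:
    """Hypothetical training dataset that crops only where patch_location says."""

    def __init__(self, data: np.ndarray, patch_size: int = 2) -> None:
        self.data = data  # shape: (N, H, W)
        self.patch_size = patch_size

    def patch_location(self, index: int) -> Tuple[int, int, int]:
        # random crop location, as during training
        n = np.random.randint(0, len(self.data))
        h = np.random.randint(0, self.data.shape[1] - self.patch_size + 1)
        w = np.random.randint(0, self.data.shape[2] - self.patch_size + 1)
        return (n, h, w)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> np.ndarray:
        n, h, w = self.patch_location(index)
        p = self.patch_size
        return self.data[n, h:h + p, w:w + p]


data = np.arange(16).reshape(1, 4, 4)
grid = [(0, 0, 0), (0, 0, 2), (0, 2, 0), (0, 2, 2)]  # 2x2 crops covering the 4x4 image
TiledToy = make_tiled_class(ToyDataset, grid)
ds = TiledToy(data, patch_size=2)
print(len(ds), ds[1].tolist())  # 4 [[2, 3], [6, 7]]
```

Because `ToyDataset.__getitem__` consults only `patch_location`, swapping that one method turns random training crops into a deterministic sequence of grid crops without touching the rest of the class.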

Note that your dataset class can be arbitrarily complex (augmentations, returning multiple patches, working with 3D data, etc.). The only requirement is that it should use the crop at the location returned by the patch_location method. Below is an example of a simple dataset class that can be used with PredTiler.

import numpy as np
from typing import Tuple

class YourDataset:
    def __init__(self, data_path, patch_size=64) -> None:
        self.patch_size = patch_size
        # load_data stands for your own loading function
        self.data = load_data(data_path) # shape: (N, H, W)

    def patch_location(self, index: int) -> Tuple[int, int, int]:
        # during training, it just ignores the index and returns a random location
        n_idx = np.random.randint(0, len(self.data))
        # the +1 makes the last valid starting position reachable
        h = np.random.randint(0, self.data.shape[1] - self.patch_size + 1)
        w = np.random.randint(0, self.data.shape[2] - self.patch_size + 1)
        return (n_idx, h, w)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        n_idx, h, w = self.patch_location(index)
        # return the (patch_size, patch_size) crop at that location
        return self.data[n_idx, h:h+self.patch_size, w:w+self.patch_size]

Getting overlapping patches needed for tiled prediction

To use PredTiler, create a new class that wraps your dataset class. This also requires a tile manager, which manages the tiles.

from predtiler import get_tiling_dataset, get_tile_manager, stitch_predictions

patch_size = 256
tile_size = 128
data_shape = (10, 2048, 2048) # shape of the data you are working with: (N, H, W)
manager = get_tile_manager(data_shape=data_shape,
                           tile_shape=(1, tile_size, tile_size),
                           patch_shape=(1, patch_size, patch_size))

dset_class = get_tiling_dataset(YourDataset, manager)
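As a sanity check on these settings, you can estimate how many patches the grid contains. The sketch assumes one patch per tile on a regular grid, rounding up at the borders; that is an assumption, and PredTiler's tile manager is the authority on the exact count. The helper `count_tiles` is hypothetical, not part of the predtiler API. For the configuration above it gives 10 * 16 * 16 tiles.

```python
import math
from typing import Tuple


def count_tiles(data_shape: Tuple[int, ...], tile_shape: Tuple[int, ...]) -> int:
    """Number of tiles on a regular grid (hypothetical helper, not predtiler API)."""
    # ceil so that a partial tile at the border still counts as a tile
    return math.prod(math.ceil(d / t) for d, t in zip(data_shape, tile_shape))


n_tiles = count_tiles((10, 2048, 2048), (1, 128, 128))
print(n_tiles)  # 2560
```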

At this point, you can use dset_class just as you would use your YourDataset class.

data_path = ... # path to your data
dset = dset_class(data_path, patch_size=patch_size)

Stitching the predictions

Because PredTiler generates the patches on a tile grid with an overlap of (patch_size - tile_size)//2, you can run your model on the wrapped dataset without any tiling logic of your own, and then pass the predictions, together with the tile manager, to stitch_predictions to obtain the full-size prediction.
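Conceptually, stitching keeps only the central tile_size region of each prediction and discards the (patch_size - tile_size)//2 border on each side, where the network saw less context. The NumPy sketch below shows this for a single 2D image. It rests on stated assumptions: image sides divisible by tile_size, an even patch_size - tile_size, reflect padding at the borders, and patches in row-major grid order. The helpers `extract_patches` and `stitch` are hypothetical, and PredTiler's own stitch_predictions may order, pad, or crop differently.

```python
import numpy as np


def extract_patches(image: np.ndarray, tile_size: int, patch_size: int) -> np.ndarray:
    """Cut a 2D image into overlapping patches centred on a tile grid (hypothetical helper)."""
    assert (patch_size - tile_size) % 2 == 0, "assumes an even patch/tile difference"
    margin = (patch_size - tile_size) // 2
    # reflect padding gives border tiles some context too
    padded = np.pad(image, margin, mode="reflect")
    patches = []
    for h in range(0, image.shape[0], tile_size):
        for w in range(0, image.shape[1], tile_size):
            # the tile at image (h, w) sits at padded (h + margin, w + margin),
            # so its patch starts at padded (h, w)
            patches.append(padded[h:h + patch_size, w:w + patch_size])
    return np.stack(patches)


def stitch(patches: np.ndarray, image_shape, tile_size: int, patch_size: int) -> np.ndarray:
    """Keep the central tile of each patch and write it back in row-major grid order."""
    margin = (patch_size - tile_size) // 2
    out = np.zeros(image_shape, dtype=patches.dtype)
    idx = 0
    for h in range(0, image_shape[0], tile_size):
        for w in range(0, image_shape[1], tile_size):
            center = patches[idx][margin:margin + tile_size, margin:margin + tile_size]
            out[h:h + tile_size, w:w + tile_size] = center
            idx += 1
    return out


rng = np.random.default_rng(0)
image = rng.random((64, 64))
patches = extract_patches(image, tile_size=16, patch_size=32)
predictions = patches  # stand-in for model outputs (an identity "model")
stitched = stitch(predictions, image.shape, tile_size=16, patch_size=32)
print(patches.shape, np.allclose(stitched, image))  # (16, 32, 32) True
```

With an identity "model" the round trip reproduces the input exactly, which is a convenient way to check that a tiling and stitching setup is consistent before plugging in a real network.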

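import numpy as np
import torch

model = ... # your model
model.eval()
predictions = []
with torch.no_grad(): # no gradients needed; also lets .numpy() be called on the output
    for i in range(len(dset)):
        inp = dset[i]
        # add batch and channel dimensions: (1, 1, patch_size, patch_size)
        inp = torch.Tensor(inp)[None, None]
        pred = model(inp)
        # drop the batch dimension; .cpu() is required if the model runs on a GPU
        predictions.append(pred[0].cpu().numpy())

predictions = np.stack(predictions) # shape: (number_of_patches, C, patch_size, patch_size)
stitched_pred = stitch_predictions(predictions, dset.tile_manager)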

Download files

Source Distribution

predtiler-0.0.2.tar.gz (11.1 kB)

Built Distribution

predtiler-0.0.2-py3-none-any.whl (7.5 kB)

File details

Details for the file predtiler-0.0.2.tar.gz.

File metadata

  • Download URL: predtiler-0.0.2.tar.gz
  • Upload date:
  • Size: 11.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.28.1

File hashes

Hashes for predtiler-0.0.2.tar.gz
Algorithm Hash digest
SHA256 374330bb2b0ee8dda8462d2873ac5907320b6e348023a0bcfb4f9d660057065e
MD5 ca05e73dcd98a676ab09d09d30335831
BLAKE2b-256 5015d84a181a509ff0b0c05c90b38adbc857baab1a3d1ae820d9804f2468e811

File details

Details for the file predtiler-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: predtiler-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 7.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.28.1

File hashes

Hashes for predtiler-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 9c509d68547b3b95a6e96a5febf65bcc42af791ca49c933d88397a756d74c3ec
MD5 ef05e0ec6689e29d3050279b161dba3b
BLAKE2b-256 724eef0d553d521798cdd1f8d21ff78f572932ccdd048c927976901258960048
