Extract and Merge Batches/Image patches (tf/torch) for easy, fast and self-contained digital image processing and deep learning model training.

License: MIT

Extract and Merge Image Patches (EMPatches)

Extract and merge batches/image patches (tf/torch) for fast and self-contained digital image processing and deep learning model training.

  • Extract patches
  • Merge the extracted patches to obtain the original image back.

Update 0.2.3 (Bug Fix)

  • Fixed a bug that occurred while merging tensors.
    Thanks to MRLBradley for noticing.

Update 0.2.2 (New Functionalities)

  • Handling 1D spectral and 3D volumetric data structures, thanks to antonyvam.
  • Batch processing support for 1D, 2D, 3D (image/pixel + voxel/volumetric) data added.
  • Bug fixes for multi-dimensional image patch merging for C > 3.

Update 0.2.0

  • Handling of tensorflow/pytorch batched images of shape BxCxHxW -> pytorch or BxHxWxC -> tf. C can be any number, not limited to just RGB channels.
  • Modes added for merging patches.
    1. overwrite: the next patch overwrites the overlapping area of the previous patch.
    2. max: the maximum value of the overlapping area at each pixel is written.
    3. min: the minimum value of the overlapping area at each pixel is written.
    4. avg: the mean/average value of the overlapping area at each pixel is written.
  • Patching via providing indices.
  • Strided patching thanks to Andreasgejlm
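The four merge modes listed above can be illustrated on two overlapping 1D patches. This is a minimal NumPy sketch of the semantics only, not the library's implementation; `merge_1d` and its `spans` argument are hypothetical names made up for this example.

```python
import numpy as np

def merge_1d(patches, spans, length, mode="overwrite"):
    """Hypothetical helper: merge 1D patches placed at [start, stop) spans."""
    if mode == "avg":
        total = np.zeros(length, dtype=float)
        count = np.zeros(length, dtype=float)
        for patch, (start, stop) in zip(patches, spans):
            total[start:stop] += patch
            count[start:stop] += 1
        return total / np.maximum(count, 1)  # avoid dividing by zero where nothing was written
    out = np.full(length, np.nan)  # NaN marks positions not written yet
    for patch, (start, stop) in zip(patches, spans):
        region = out[start:stop]  # view into the output array
        if mode == "overwrite":
            region[:] = patch
        elif mode == "max":
            region[:] = np.fmax(region, patch)  # fmax ignores the NaN placeholder
        elif mode == "min":
            region[:] = np.fmin(region, patch)
        else:
            raise ValueError(f"unknown mode: {mode}")
    return out

patches = [np.array([1.0, 2.0, 3.0, 4.0]), np.array([10.0, 0.0, 5.0, 6.0])]
spans = [(0, 4), (2, 6)]  # the two patches overlap on positions 2 and 3
results = {m: merge_1d(patches, spans, 6, m) for m in ("overwrite", "max", "min", "avg")}
for name, merged in results.items():
    print(name, merged)
```

Only positions 2 and 3 differ between the modes, since that is the only area where the two patches overlap.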

Dependencies

python >= 3.6
numpy 
math

Usage

Extracting Patches

import cv2
import numpy as np
import matplotlib.pyplot as plt
import imgviz  # just for plotting
from empatches import EMPatches

# get image either RGB or Grayscale
img = cv2.imread('../digits.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# load module
emp = EMPatches()
img_patches, indices = emp.extract_patches(img, patchsize=512, overlap=0.2)

# display the extracted image patches as one tiled image
tiled = imgviz.tile(list(map(np.uint8, img_patches)), border=(255, 0, 0))
plt.figure()
plt.imshow(tiled)
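To get a feel for how `patchsize` and `overlap` determine where patches land, the sketch below computes patch start offsets along one image axis. It assumes that `overlap` is a fraction of the patch size and that the last patch is shifted back to stay inside the image; the library's actual rounding and edge handling may differ. `patch_starts` is a hypothetical helper, not part of empatches.

```python
def patch_starts(length, patchsize, overlap=0.2):
    """Hypothetical helper: start offsets of patches along one axis."""
    # Assumed relation: consecutive patches are patchsize * (1 - overlap) apart.
    step = max(1, int(patchsize * (1 - overlap)))
    last = max(length - patchsize, 0)  # last start that keeps the patch inside
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        starts.append(last)  # shift the final patch back so it fits
    return starts

starts = patch_starts(length=1200, patchsize=512, overlap=0.2)
print(starts)
print(patch_starts(length=16, patchsize=8, overlap=0.0))
```

Under these assumptions a 1200-pixel axis is covered by three 512-pixel patches, the last of which overlaps its neighbour by more than 20% so that no pixel is left out.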

Image Processing

Now we can process each patch independently and, once we are done, merge the patches back together.

# Pseudocode: do some processing and store the patches in a list in the same order, e.g.
#   img_patches_processed = some_processing_func(img_patches)
# or run your deep learning model on the patches independently and then merge the predictions:
#   img_patches_processed = model.predict(img_patches)

# For now, let's just flip the channels of the second patch
img_patches[1] = cv2.cvtColor(img_patches[1], cv2.COLOR_BGR2RGB)

Merging Patches

After processing the patches, you can merge all of them back into the original form as follows:

merged_img = emp.merge_patches(img_patches, indices, mode='max') # or
merged_img = emp.merge_patches(img_patches, indices, mode='min') # or
merged_img = emp.merge_patches(img_patches, indices, mode='overwrite') # or
merged_img = emp.merge_patches(img_patches, indices, mode='avg')
# display
plt.figure()
plt.imshow(merged_img.astype(np.uint8))
plt.title('avg mode')  # use the name of the mode you chose
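The extract/merge round trip can be checked on a small array. The sketch below is a plain NumPy stand-in that assumes indices are stored as `(y1, y2, x1, x2)` tuples; `extract_grid` and `merge_avg` are hypothetical helpers, not the library's functions. It shows that merging untouched patches restores the image, and that a change made to one patch is diluted by averaging wherever patches overlap.

```python
import numpy as np

def extract_grid(img, size, starts):
    """Hypothetical helper: cut square patches at the given row/column starts."""
    patches, indices = [], []
    for y in starts:
        for x in starts:
            patches.append(img[y:y + size, x:x + size].astype(float))
            indices.append((y, y + size, x, x + size))  # assumed (y1, y2, x1, x2) order
    return patches, indices

def merge_avg(patches, indices, shape):
    """Hypothetical helper: average all patch values written to each pixel."""
    total = np.zeros(shape)
    count = np.zeros(shape)
    for patch, (y1, y2, x1, x2) in zip(patches, indices):
        total[y1:y2, x1:x2] += patch
        count[y1:y2, x1:x2] += 1
    return total / count  # every pixel is covered at least once in this example

img = np.arange(36, dtype=float).reshape(6, 6)
patches, indices = extract_grid(img, size=4, starts=[0, 2])
restored = merge_avg(patches, indices, img.shape)

patches[0] = patches[0] + 8.0  # "process" only the first patch
changed = merge_avg(patches, indices, img.shape)
print(np.allclose(restored, img), changed[0, 0], changed[3, 3])
```

Pixel (0, 0) belongs to one patch only and receives the full change, while pixel (3, 3) is shared by four patches and receives a quarter of it.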

Strided Patching

img_patches, indices = emp.extract_patches(img, patchsize=512, overlap=0.2, stride=128)
tiled= imgviz.tile(list(map(np.uint8, img_patches)),border=(255,0,0))
plt.figure()
plt.imshow(tiled.astype(np.uint8))
plt.title('Strided patching')

Volumetric/Voxel data patching

# first generate some sample data
def midpoints(x):
    sl = ()
    for i in range(x.ndim):
        x = (x[sl + np.index_exp[:-1]] + x[sl + np.index_exp[1:]]) / 2.0
        sl += np.index_exp[:]
    return x
r, g, b = np.indices((17, 17, 17)) / 16.0
rc = midpoints(r)
gc = midpoints(g)
bc = midpoints(b)
# define a sphere about [0.5, 0.5, 0.5]
sphere = ((rc - 0.5)**2 + (gc - 0.5)**2 + (bc - 0.5)**2 < 0.5**2).astype(int)

ax = plt.figure().add_subplot(projection='3d')
ax.voxels(sphere)
plt.title(f'Voxel 3D data: {sphere.shape} shape')

Extract patches from voxel 3D data.

emp = EMPatches()
patches, indices  = emp.extract_patches(sphere, patchsize=8, overlap=0.0, stride=None, vox=True)

ax = plt.figure().add_subplot(projection='3d')
ax.voxels(patches[1])
plt.title(f'Patched Voxel 3D data: {patches[0].shape} shape')

# print the shape and the indices of every patch
for i in range(len(patches)):
    print(patches[i].shape, indices[i])

mp = emp.merge_patches(patches, indices)
###############___VOXEL DATA___ setting vox to True ########################
##  shape     indices in xyz dimension
>> (8, 8, 8) (0, 8, 0, 8, 0, 8)
>> (8, 8, 8) (0, 8, 0, 8, 8, 16)
>> (8, 8, 8) (8, 16, 0, 8, 0, 8)
>> (8, 8, 8) (8, 16, 0, 8, 8, 16)
>> (8, 8, 8) (0, 8, 8, 16, 0, 8)
>> (8, 8, 8) (0, 8, 8, 16, 8, 16)
>> (8, 8, 8) (8, 16, 8, 16, 0, 8)
>> (8, 8, 8) (8, 16, 8, 16, 8, 16)

⚠️NOTE⚠️

Here the output shape is 8x8x8, i.e. the cropping is also done along the D/C dimension. This is unlike image cropping/patching, where the output would have shape 8x8x3 (RGB) or 8x8 (grayscale) because the last dimension is left whole. With vox set to False, the same data gives patches and indices like:

###############___PIXEL DATA___ -> setting vox to False ########################
##  shape     indices in xy dimension
>> (8, 8, 16) (0, 8, 0, 8)
>> (8, 8, 16) (8, 16, 0, 8)
>> (8, 8, 16) (0, 8, 8, 16)
>> (8, 8, 16) (8, 16, 8, 16)
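The difference between the two index formats can be reproduced with plain NumPy slicing. This is only an illustration of the shapes listed above, assuming the six voxel indices are three (start, stop) pairs, one per axis, and the four pixel indices are two pairs for the first two axes.

```python
import numpy as np

volume = np.zeros((16, 16, 16))

# vox=True style: six numbers per patch, every axis is cropped
a1, a2, b1, b2, c1, c2 = (0, 8, 8, 16, 0, 8)
voxel_patch = volume[a1:a2, b1:b2, c1:c2]

# vox=False style: four numbers per patch, the last axis is kept whole
r1, r2, s1, s2 = (8, 16, 0, 8)
pixel_patch = volume[r1:r2, s1:s2]

print(voxel_patch.shape, pixel_patch.shape)
```

With a patch size of 8 and no overlap, a 16x16x16 array therefore yields 2 x 2 x 2 = 8 voxel patches but only 2 x 2 = 4 pixel patches.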

1D spectral Data patching

x1 = np.linspace(0.0, 5.0)
y1 = np.cos(5 * np.pi * x1) * np.exp(-x1)
plt.plot(y1)
plt.title('1D spectra')

emp = EMPatches()
patches, indices  = emp.extract_patches(y1, patchsize=8, overlap=0.0, stride=None)

ax1 = plt.subplot(211)  # top panel of a 2-row figure
plt.plot(patches[0]) # 0th patch
ax2 = plt.subplot(212, sharex=ax1, sharey=ax1)  # bottom panel, sharing both axes
plt.plot(patches[2]) # 2nd patch
plt.suptitle('patched 1D spectra')
# merge again
mp = emp.merge_patches(patches, indices)

Batched Patching

Things to know.

  • batch : batch of images of shape either BxCxHxW -> pytorch or BxHxWxC -> tf to extract patches from. Patches are returned in list(list1, list2, ...), where list1 -> ([H W C], [H W C], ...) and so on.

  • patchsize : size of the patch to extract from the image; only square patches can be extracted for now.

  • overlap (Optional): overlap between patches as a fraction, a float in [0, 1].

  • stride (Optional): step size between patches.

  • typ (Optional): type of batched images, either 'tf' or 'torch'.

  • batch_patches : a list containing lists of extracted patches of images.

  • batch_indices : a list containing lists of indices of patches in order, which can be used at a later stage for merging the patches.

  • merged_batch : a np array of shape BxCxHxW -> pytorch or BxHxWxC -> tf.
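The two batch layouts named above differ only in where the channel axis sits. As a small NumPy sketch (independent of empatches), a batch can be moved from one layout to the other with `np.transpose`:

```python
import numpy as np

batch_torch = np.zeros((4, 3, 64, 48))              # B x C x H x W (pytorch layout)
batch_tf = np.transpose(batch_torch, (0, 2, 3, 1))  # B x H x W x C (tf layout)
back = np.transpose(batch_tf, (0, 3, 1, 2))         # and back to B x C x H x W

print(batch_tf.shape, back.shape)
```

This is handy when a batch arrives in one layout but downstream plotting code, such as `plt.imshow`, expects the channel axis last.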

Extraction

from empatches import BatchPatching

bp = BatchPatching(patchsize=512, overlap=0.2, stride=None, typ='torch')
# extracting
batch_patches, batch_indices = bp.patch_batch(batch) # batch of shape BxCxHxW, C can be any number, not just 3

plt.imshow(batch_patches[1][2])
plt.title('3rd patch of 2nd image in batch')

Merging

# merging
# the output shape depends on the typ variable:
# BxCxHxW -> torch or BxHxWxC -> tf
merged_batch = bp.merge_batch(batch_patches, batch_indices, mode='avg') 

# accessing the merged images
plt.imshow(merged_batch[1,...].astype(np.uint8))
plt.title('2nd merged image in batch')

Patching via Providing Indices

NOTE: in this case merging is not supported.

from empatches import patch_via_indices

img = cv2.imread('./digit.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (1024, 512))

indices = [(0, 512, 0, 256),   # 1st patch dims/indices
           (0, 256, 310, 922), # 2nd patch dims/indices
           (0, 512, 512, 768)] # 3rd patch dims/indices
img_patches = patch_via_indices(img, indices)

# plotting
tiled= imgviz.tile(list(map(np.uint8, img_patches)),border=(255,0,0))
plt.figure()
plt.imshow(tiled.astype(np.uint8))
plt.title('patching via providing indices')
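Conceptually, patching via indices amounts to slicing the image once per index tuple. The sketch below uses plain NumPy and assumes each tuple is ordered `(y1, y2, x1, x2)`; `crop_by_indices` is a hypothetical stand-in, not the library's `patch_via_indices`.

```python
import numpy as np

def crop_by_indices(img, indices):
    """Hypothetical helper: one crop per (y1, y2, x1, x2) tuple."""
    return [img[y1:y2, x1:x2] for (y1, y2, x1, x2) in indices]

img = np.zeros((512, 1024, 3), dtype=np.uint8)  # same size as the resized image above
indices = [(0, 512, 0, 256), (0, 256, 310, 922), (0, 512, 512, 768)]
shapes = [patch.shape for patch in crop_by_indices(img, indices)]
print(shapes)
```

Patches obtained this way can have different sizes and do not have to cover the whole image.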

For more information, visit the Homepage.
