This module defines tools for image data preprocessing and real-time data augmentation that are used to train a model.

Plot images and masks

show

show(*obj, file_name=None, overlay=False, pred=False, show_bbox=True, figsize=(10, 10), cmap='binary_r', **kwargs)

Show image, mask, and weight (optional)

The show methods in fastai rely on types that know how to display themselves, so we create a new type with a show method.

Typedispatch

Custom show_batch and show_results for DataLoader
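Type dispatch means that one function name has several implementations, and the one that runs is picked from the argument types. The sketch below uses the standard-library functools.singledispatch as a stand-in to illustrate the idea; unlike fastai's typedispatch, it looks only at the type of the first argument. ImageMaskBatch and show_batch_sketch are hypothetical names, not part of fastai or this module.

```python
from functools import singledispatch


class ImageMaskBatch(tuple):
    "Hypothetical marker type for an (images, masks) batch."


@singledispatch
def show_batch_sketch(batch):
    # Fallback used for any type without a more specific implementation
    return f"generic display of {len(batch)} items"


@show_batch_sketch.register
def _(batch: ImageMaskBatch):
    # Selected automatically because the argument is an ImageMaskBatch
    images, masks = batch
    return f"image/mask display of {len(images)} pairs"


print(show_batch_sketch((1, 2, 3)))                         # generic display of 3 items
print(show_batch_sketch(ImageMaskBatch(([0, 1], [0, 1]))))  # image/mask display of 2 pairs
```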

Example image and mask

We use an example image and mask to guide you through the documentation.

Plot example image and mask

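import imageio
from pathlib import Path
path = Path('sample_data')  # folder containing the example 'images' and 'labels'
image = imageio.imread(path/'images'/'01.png')
mask = imageio.imread(path/'labels'/'01_mask.png')
show(image, mask)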

Mask preprocessing

Supported segmentation mask types:

  • Class labels: pixel annotations of classes (e.g., 0 for background and 1...n for positive classes)
  • Instance labels: pixel annotations of the instance each pixel belongs to (e.g., 0 for background, 1 for the first ROI, 2 for the second ROI, etc.)

_show(mask, inst_labels)

The provided segmentation masks are preprocessed to:

  • convert instance labels to class labels
  • draw small ridges between touching instances (optional)
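A minimal NumPy sketch of these two steps, assuming 0 marks background in the instance mask. The function name instances_to_classes is hypothetical, and the ridge rule used here (drop every foreground pixel that has a 4-neighbour from a different instance) is one simple choice, not necessarily the rule preprocess_mask applies.

```python
import numpy as np


def instances_to_classes(inst, draw_ridges=True):
    """Binary class mask (0/1) from an instance-label mask where 0 is background.

    With draw_ridges=True, every foreground pixel that has a 4-neighbour with a
    different non-zero instance label is set to background, leaving a thin gap
    between touching instances.
    """
    inst = np.asarray(inst)
    clabels = (inst > 0).astype(np.uint8)
    if draw_ridges:
        h, w = inst.shape
        padded = np.pad(inst, 1, mode="constant")  # zero border counts as background
        touching = np.zeros(inst.shape, dtype=bool)
        for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
            # neigh[i, j] is the neighbour of pixel (i, j) at offset (dy, dx)
            neigh = padded[1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
            touching |= (inst > 0) & (neigh > 0) & (neigh != inst)
        clabels[touching] = 0
    return clabels


print(instances_to_classes([[1, 1, 2, 2]]))                     # [[1 0 0 1]]
print(instances_to_classes([[1, 1, 2, 2]], draw_ridges=False))  # [[1 1 1 1]]
```

For two touching instances in one row, both boundary pixels become background, so the ridge in this sketch is two pixels wide.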

preprocess_mask

preprocess_mask(clabels=None, instlabels=None, ignore=None, remove_overlap=True, n_dims=2)

Calculates the weights from the given mask (class labels clabels or instance labels instlabels).

Arguments of preprocess_mask:

  • clabels: class labels (segmentation mask)
  • instlabels: instance labels (segmentation mask)
  • n_dims (int): number of classes for clabels

tst1 = preprocess_mask(mask, remove_overlap=False)
tst2 = preprocess_mask(mask, remove_overlap=True)
_show(tst1,tst2)
ind = (slice(200,230), slice(230,260))
print('Zoom in on borders:')
_show(tst1[ind], tst2[ind])

Data augmentation

The DeformationField class ensures that all augmentations are applied identically to images, masks, and weights. The implemented augmentations are:

  • rotation
  • mirroring
  • random deformation
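The underlying idea can be sketched as follows: the geometry is described once, as a map from output pixels to input coordinates, and that same map is then applied to every array. The NumPy sketch below (hypothetical helper names, nearest-neighbour sampling only) builds such a coordinate grid, mirrors it, and samples an image and its mask with the same grid, so the augmented mask still lines up with the augmented image.

```python
import numpy as np


def identity_grid(shape):
    "Input (row, col) coordinates for every output pixel, shape (2, H, W)."
    rows, cols = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    return np.stack([rows, cols]).astype(float)


def mirror_grid(grid, axes):
    "Flip the coordinate grid along spatial axes (0 = rows, 1 = columns)."
    for ax in axes:
        grid = np.flip(grid, axis=ax + 1)  # +1 skips the leading (row, col) axis
    return grid


def sample_nearest(arr, grid):
    "Read arr at the rounded grid coordinates, clipped to the array bounds."
    rows = np.clip(np.rint(grid[0]).astype(int), 0, arr.shape[0] - 1)
    cols = np.clip(np.rint(grid[1]).astype(int), 0, arr.shape[1] - 1)
    return arr[rows, cols]


rng = np.random.default_rng(0)
img = rng.random((4, 5))
msk = (img > 0.5).astype(np.uint8)

# One grid, applied to both arrays
grid = mirror_grid(identity_grid(img.shape), axes=(0, 1))
img_aug, msk_aug = sample_nearest(img, grid), sample_nearest(msk, grid)
print(np.array_equal(img_aug, img[::-1, ::-1]))                    # True
print(np.array_equal(msk_aug, (img_aug > 0.5).astype(np.uint8)))   # True
```

Nearest-neighbour lookup keeps mask labels intact; for the image itself, a real implementation would more likely interpolate linearly.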

class DeformationField

DeformationField(shape=(540, 540), scale=1, scale_range=(0, 0), p_scale=1.0)

Creates a deformation field for data augmentation
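For the random deformation, one common construction (assumed here, not taken from the DeformationField source) draws Gaussian displacement vectors on a coarse grid and interpolates them smoothly to every pixel. The parameter names echo the deformation_grid and deformation_magnitude arguments listed for RandomTileDataset below, but the functions, their default values, and the bilinear upsampling are hypothetical.

```python
import numpy as np


def upsample_bilinear(coarse, out_h, out_w):
    "Bilinearly resize a small 2-D array to (out_h, out_w) using np.interp."
    gh, gw = coarse.shape
    ys = np.linspace(0, gh - 1, out_h)
    xs = np.linspace(0, gw - 1, out_w)
    rows = np.array([np.interp(xs, np.arange(gw), r) for r in coarse])    # (gh, out_w)
    return np.array([np.interp(ys, np.arange(gh), c) for c in rows.T]).T  # (out_h, out_w)


def random_deformation_grid(shape, deformation_grid=(150, 150),
                            deformation_magnitude=(10.0, 10.0), rng=None):
    """Identity coordinates plus a smooth random displacement field (hypothetical sketch).

    Displacements ~ N(0, deformation_magnitude) are drawn on a coarse grid whose
    node spacing is at most deformation_grid pixels, then interpolated to every pixel.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = shape
    gh = int(np.ceil(h / deformation_grid[0])) + 1
    gw = int(np.ceil(w / deformation_grid[1])) + 1
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    grid = np.stack([rows, cols]).astype(float)
    for axis in range(2):
        coarse = rng.normal(0.0, deformation_magnitude[axis], size=(gh, gw))
        grid[axis] += upsample_bilinear(coarse, h, w)
    return grid


def sample_nearest(arr, grid):
    "Read arr at the rounded grid coordinates, clipped to the array bounds."
    rows = np.clip(np.rint(grid[0]).astype(int), 0, arr.shape[0] - 1)
    cols = np.clip(np.rint(grid[1]).astype(int), 0, arr.shape[1] - 1)
    return arr[rows, cols]


rng = np.random.default_rng(0)
img = rng.random((64, 64))
msk = (img > 0.5).astype(np.uint8)
grid = random_deformation_grid(img.shape, (32, 32), (4.0, 4.0), rng=rng)
img_aug, msk_aug = sample_nearest(img, grid), sample_nearest(msk, grid)
print(np.array_equal(msk_aug, (img_aug > 0.5).astype(np.uint8)))  # True: same field for both
```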

Original Image

tst = DeformationField(shape=(260, 260), scale=1, scale_range=(0.7, 1.4))
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)))

Add mirroring

tst = DeformationField()
tst.mirror((1,1))
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)))

Add rotation

tst.rotate(1)
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)))
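Rotation fits the same scheme: every output pixel reads the input at its coordinates rotated about the image centre. The sketch below assumes the angle is given in radians (the unit expected by tst.rotate is not stated on this page), uses hypothetical helper names, and clips coordinates that fall outside the image to the nearest border pixel.

```python
import numpy as np


def rotation_grid(shape, angle):
    """Input (row, col) coordinates for rotating an image by `angle` radians about its centre.

    With rows pointing down (as in an image display), a positive angle turns the
    displayed image clockwise.
    """
    h, w = shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    cy, cx = (h - 1) / 2, (w - 1) / 2
    c, s = np.cos(angle), np.sin(angle)
    # Inverse mapping: for each output pixel, where in the input do we read from?
    src_rows = cy + c * (rows - cy) - s * (cols - cx)
    src_cols = cx + s * (rows - cy) + c * (cols - cx)
    return np.stack([src_rows, src_cols])


def sample_nearest(arr, grid):
    "Read arr at the rounded grid coordinates, clipped to the array bounds."
    rows = np.clip(np.rint(grid[0]).astype(int), 0, arr.shape[0] - 1)
    cols = np.clip(np.rint(grid[1]).astype(int), 0, arr.shape[1] - 1)
    return arr[rows, cols]


img = np.arange(16).reshape(4, 4)
print(np.array_equal(sample_nearest(img, rotation_grid(img.shape, np.pi / 2)), np.rot90(img, k=-1)))  # True
print(np.array_equal(sample_nearest(img, rotation_grid(img.shape, np.pi)), img[::-1, ::-1]))         # True
```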

Datasets

PyTorch map-style datasets for training and validation.

Helper functions

Base Class

class BaseDataset

BaseDataset(*args, **kwds) :: Dataset

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses can also optionally overwrite __len__, which many Sampler implementations and the default options of DataLoader expect to return the size of the dataset.

Note: DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
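In practice, a map-style dataset only has to answer two questions: how many samples there are, and which sample belongs to a given index. A minimal sketch with a hypothetical PairDataset class, written as plain Python so it runs without PyTorch; in real code it would subclass torch.utils.data.Dataset.

```python
class PairDataset:
    """Hypothetical minimal map-style dataset of (image, mask) pairs.

    Plain Python so it runs without PyTorch; in real code it would subclass
    torch.utils.data.Dataset.
    """

    def __init__(self, images, masks):
        if len(images) != len(masks):
            raise ValueError("need exactly one mask per image")
        self.images, self.masks = images, masks

    def __len__(self):
        # The default samplers use this to know which indices 0..len-1 exist
        return len(self.images)

    def __getitem__(self, idx):
        # Integer index -> one (image, mask) sample
        return self.images[idx], self.masks[idx]


ds = PairDataset(["img0", "img1"], ["mask0", "mask1"])
print(len(ds), ds[1])  # 2 ('img1', 'mask1')
```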

from pathlib import Path
from fastai.vision.all import get_image_files
path = Path('sample_data')
files = get_image_files(path/'images')
# map an image file to its mask, e.g. images/01.png -> labels/01_mask.png
label_fn = lambda o: path/'labels'/f'{o.stem}_mask.png'
tst = BaseDataset(files, label_fn=label_fn)
tst.show_data()
tst.clear_cached_weights()

RandomTileDataset

For training

class RandomTileDataset

RandomTileDataset(*args, **kwds) :: BaseDataset

PyTorch Dataset that creates random tiles with augmentations from the input images.

Args:

  • csv_file (string): Path to the csv file with annotations.
  • root_dir (string): Directory with all the images.
  • tile_shape: The tile shape the network expects as input.
  • padding: The padding (input shape - output shape).
  • classlabels: A list containing the corresponding class labels (0 = ignore, 1 = background, 2-n = foreground classes). If None, the problem is treated as binary segmentation.
  • n_classes: The number of classes, including background.
  • ignore: A list containing the corresponding ignore regions.
  • weights: A list containing the corresponding weights.
  • element_size_um: The target pixel size in micrometers.
  • batch_size: The number of tiles to generate per batch.
  • rotation_range_deg (alpha_min, alpha_max): The range of rotation angles. A random rotation is drawn from a uniform distribution in the given range.
  • flip: If true, a coin flip decides whether a mirrored tile is generated.
  • deformation_grid (dx, dy): The distance in pixels between neighboring grid points at which random deformation vectors are drawn.
  • deformation_magnitude (sx, sy): The standard deviations of the Gaussians from which the components of the deformation vector are drawn.
  • value_minimum_range (v_min, v_max): Input intensity zero is mapped to a random value in the given range.
  • value_maximum_range (v_min, v_max): Input intensity one is mapped to a random value in the given range.
  • value_slope_range (s_min, s_max): The slope at control points is drawn from a uniform distribution in the given range.
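The three value_*_range arguments describe a random intensity transfer curve. One way to realise it, assuming inputs scaled to [0, 1] and exactly two control points at 0 and 1, is a cubic Hermite spline through the drawn end values and slopes. The function name and the default ranges below are hypothetical, not the library's.

```python
import numpy as np


def random_value_curve(rng, value_minimum_range=(0.0, 0.2),
                       value_maximum_range=(0.8, 1.0), value_slope_range=(0.8, 1.2)):
    """Draw a random intensity curve f on [0, 1] (hypothetical sketch).

    f(0) ~ U(value_minimum_range), f(1) ~ U(value_maximum_range), and the slopes at
    the control points x = 0 and x = 1 ~ U(value_slope_range). Between them, f is
    the cubic Hermite spline through these values and slopes.
    """
    v0 = rng.uniform(*value_minimum_range)
    v1 = rng.uniform(*value_maximum_range)
    m0, m1 = rng.uniform(*value_slope_range, size=2)

    def f(x):
        t = np.clip(np.asarray(x, dtype=float), 0.0, 1.0)
        h00 = 2 * t**3 - 3 * t**2 + 1   # weight of f(0)
        h10 = t**3 - 2 * t**2 + t       # weight of the slope at 0
        h01 = -2 * t**3 + 3 * t**2      # weight of f(1)
        h11 = t**3 - t**2               # weight of the slope at 1
        return h00 * v0 + h10 * m0 + h01 * v1 + h11 * m1

    return f, (v0, v1, m0, m1)


rng = np.random.default_rng(0)
f, (v0, v1, m0, m1) = random_value_curve(rng)
img = rng.random((8, 8))  # stand-in for an image scaled to [0, 1]
img_aug = f(img)          # same shape, intensities remapped through the random curve
```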

Show data

tst = RandomTileDataset(files, label_fn=label_fn, verbose=2, scale=1)
tst.show_data()

Show random tile (default padding = (184,184))

tile = tst[0]
show(tile[0], tile[1])
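Since the padding is defined as input shape minus output shape, the region the network actually predicts for is smaller than the tile. A one-line helper (hypothetical name) makes the arithmetic explicit; the example uses the input and output sizes of the original U-Net, 572 and 388, whose difference is exactly the default padding of 184.

```python
def output_shape(tile_shape, padding):
    "Spatial size the network predicts for, given the input tile shape and padding."
    return tuple(t - p for t, p in zip(tile_shape, padding))


print(output_shape((572, 572), (184, 184)))  # (388, 388)
```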

Compute stats

import matplotlib.pyplot as plt
tst.compute_stats()
img_path = tst.files[0]
cdf = tst.pdfs[img_path.name][:]
# draw 500 random tile centers for the first image
centers = [tst._random_center(cdf, mask.shape) for _ in range(500)]
plt.imshow(mask)
xs = [x[1] for x in centers]  # column (x) coordinates
ys = [x[0] for x in centers]  # row (y) coordinates
plt.scatter(x=xs, y=ys, c='r', s=10)
plt.show()
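tst.pdfs appears to store, per image, the distribution from which tile centers are drawn. How deepflash2 builds that distribution is not shown here; the sketch below only illustrates inverse-transform sampling of positions from an arbitrary non-negative 2-D weight map (with a positive sum), using a hypothetical function name.

```python
import numpy as np


def sample_centers(weights, n, rng=None):
    """Draw n (row, col) positions with probability proportional to `weights`.

    Inverse-transform sampling: build the cumulative distribution over the
    flattened map, draw uniform numbers, and locate them with a binary search.
    """
    rng = np.random.default_rng() if rng is None else rng
    cdf = np.cumsum(weights, dtype=float)             # flattened, non-decreasing
    cdf /= cdf[-1]                                    # normalise so the last entry is 1
    u = rng.random(n)                                 # uniform draws in [0, 1)
    flat_idx = np.searchsorted(cdf, u, side="right")  # first entry strictly greater than u
    return np.stack(np.unravel_index(flat_idx, np.shape(weights)), axis=1)


weights = np.zeros((5, 6))
weights[2, 3] = 1.0  # all probability mass on one pixel
print(sample_centers(weights, 3, np.random.default_rng(0)))  # every row is [2 3]
```

Zero-weight pixels are never returned, because their cumulative value equals that of the previous pixel and no uniform draw can fall between the two.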

TileDataset

class TileDataset

TileDataset(*args, **kwds) :: BaseDataset

PyTorch Dataset that creates random tiles for validation and prediction on new data.
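To predict on a whole image, tiles typically need to cover it completely. One simple, deterministic layout (an assumption for illustration, not necessarily what TileDataset does) places non-overlapping output tiles on a regular grid; grid_centers is a hypothetical helper.

```python
import math


def grid_centers(image_shape, out_shape):
    """(row, col) centres of non-overlapping output tiles covering an image.

    If the image size is not a multiple of the tile size, the last row/column of
    tiles reaches past the border, so the input would need padding there.
    """
    n_rows = math.ceil(image_shape[0] / out_shape[0])
    n_cols = math.ceil(image_shape[1] / out_shape[1])
    return [(out_shape[0] // 2 + i * out_shape[0], out_shape[1] // 2 + j * out_shape[1])
            for i in range(n_rows) for j in range(n_cols)]


print(grid_centers((448, 448), (224, 224)))
# [(112, 112), (112, 336), (336, 112), (336, 336)]
```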

Show data

tst = TileDataset(files, label_fn=label_fn, tile_shape=(224,224), padding=(0,0), scale=1, val_length=6)
tst.show_data()

Show tiles on image

  • Center points are indicated with red dots
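import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy import ndimage
fig, ax = plt.subplots(figsize=(10,10))
# show the first image, rescaled back to its original resolution
ax.imshow(ndimage.zoom(tst.get_data(max_n=1)[0][...,0], 1/tst.scale))
xs = [x[1]/tst.scale for x in tst.centers]  # column (x) coordinates of tile centers
ys = [x[0]/tst.scale for x in tst.centers]  # row (y) coordinates of tile centers
ax.scatter(x=xs, y=ys, c='r', s=20)
h, w = tst.output_shape
for xsi, ysi in zip(xs, ys):
    # Rectangle takes the corner with the smallest x/y, then width and height
    rect = patches.Rectangle((xsi-w/2, ysi-h/2), w, h, linewidth=1, edgecolor='r', alpha=0.3)
    ax.add_patch(rect)
plt.show()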