This module defines tools for image data preprocessing and real-time data augmentation used to train a model.

Plot images and masks


show(*obj, file_name=None, overlay=False, pred=False, num_classes=2, show_bbox=False, figsize=(10, 10), cmap='viridis', **kwargs)

Show image, mask, and weight (optional)

Example image and mask

We will use one example image and its mask as a running example throughout the documentation.

Plot example image and mask

image = imageio.imread(path/'images'/'01.png')
mask = imageio.imread(path/'labels'/'01_mask.png')
show(image, mask)

Mask preprocessing

Supported segmentation mask types:

  • Class labels: pixel-wise class annotations (e.g., 0 for background and 1...n for positive classes)
  • Instance labels: pixel-wise annotations that assign each pixel to an instance (e.g., 0 for background, 1 for the first ROI, 2 for the second ROI, etc.)

_show(mask, inst_labels)

The provided segmentation masks are preprocessed to

  • convert instance labels to class labels
  • draw small ridges between touching instances (optional)
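The two preprocessing steps can be illustrated with a minimal NumPy sketch. It is not the library's implementation of preprocess_mask: the helper name instance_to_class_labels is hypothetical, the output is limited to a binary class mask, and the "ridge" is assumed to be a one-pixel line set to background wherever two different instances touch horizontally or vertically.

```python
import numpy as np

def instance_to_class_labels(inst, remove_connectivity=True):
    """Hypothetical helper: binary class mask from an instance label image."""
    inst = np.asarray(inst)
    # Step 1: every labelled instance becomes foreground (class 1)
    cls = (inst > 0).astype(np.uint8)
    if remove_connectivity:
        # Step 2: find foreground pixels whose right or lower neighbour
        # belongs to a different instance and set them to background
        ridge = np.zeros(inst.shape, dtype=bool)
        right = (inst[:, :-1] != inst[:, 1:]) & (inst[:, :-1] > 0) & (inst[:, 1:] > 0)
        below = (inst[:-1, :] != inst[1:, :]) & (inst[:-1, :] > 0) & (inst[1:, :] > 0)
        ridge[:, :-1] |= right
        ridge[:-1, :] |= below
        cls[ridge] = 0
    return cls

# Two touching instances (1 and 2) on a background of zeros
inst_example = np.array([[1, 1, 2, 2],
                         [1, 1, 2, 2],
                         [0, 0, 0, 0]])
merged = instance_to_class_labels(inst_example, remove_connectivity=False)
separated = instance_to_class_labels(inst_example, remove_connectivity=True)
```

Without the ridge, both instances merge into one foreground blob in the class mask. With it, a background column separates them, so they can be recovered as distinct objects later.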


preprocess_mask(clabels=None, instlabels=None, remove_connectivity=True, num_classes=2)

Calculates the weights from the given mask (class labels clabels or instance labels instlabels).

Arguments in preprocess_mask:

  • clabels: class labels (segmentation mask)
  • instlabels: instance labels (segmentation mask)
  • num_classes (int): number of classes

tst1 = preprocess_mask(mask, remove_connectivity=False)
tst2 = preprocess_mask(mask, remove_connectivity=True)
ind = (slice(200,230), slice(230,260))
print('Zoom in on borders:')
_show(tst1[ind], tst2[ind])
Zoom in on borders:

Data augmentation

The DeformationField class ensures that all augmentations are applied identically to images, masks, and weights. The implemented augmentations are

  • rotation
  • mirroring
  • random deformation
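The idea of applying one transformation identically to several arrays can be sketched as follows. This is a simplified stand-in, not the DeformationField implementation: the class name SharedTransform is hypothetical, it covers only rotation and mirroring (no random deformation or scaling), and it uses nearest-neighbour lookup with edge clamping, which is an assumption about interpolation.

```python
import numpy as np

class SharedTransform:
    """Hypothetical stand-in: one rotation/flip reused for image and mask."""

    def __init__(self, shape, angle_deg=0.0, flip=False):
        h, w = shape
        # Output pixel coordinates relative to the tile centre
        ys, xs = np.meshgrid(np.arange(h) - (h - 1) / 2,
                             np.arange(w) - (w - 1) / 2, indexing="ij")
        if flip:
            xs = -xs  # mirror left/right
        a = np.deg2rad(angle_deg)
        # Rotate the sampling grid once; every apply() call reuses it
        self.src_y = np.cos(a) * ys - np.sin(a) * xs
        self.src_x = np.sin(a) * ys + np.cos(a) * xs

    def apply(self, data, center):
        """Sample `data` around `center` (row, col) with nearest-neighbour lookup."""
        iy = np.clip(np.rint(self.src_y + center[0]).astype(int), 0, data.shape[0] - 1)
        ix = np.clip(np.rint(self.src_x + center[1]).astype(int), 0, data.shape[1] - 1)
        return data[iy, ix]

rng = np.random.default_rng(0)
image_demo = rng.integers(0, 255, size=(8, 8))
mask_demo = (image_demo > 127).astype(np.uint8)

tf = SharedTransform(shape=(4, 4), angle_deg=90, flip=True)
image_tile = tf.apply(image_demo, center=(3.5, 3.5))
mask_tile = tf.apply(mask_demo, center=(3.5, 3.5))
```

Because the sampling grid is computed once in the constructor, the image tile and the mask tile stay pixel-aligned regardless of which rotation or flip was chosen.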

class DeformationField

DeformationField(shape=(540, 540), scale=1, scale_range=(0, 0), p_scale=1.0)

Creates a deformation field for data augmentation

Original Image

tst = DeformationField(shape=(260, 260), scale=1, scale_range=(0.7, 1.4))
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)))

Add mirroring

tst = DeformationField()
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)))

Add rotation

show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)))


PyTorch map-style datasets for training and validation.

Helper functions

path_test = path/'mask.png'
for num_classes in [2,3,10]:
    x1 = np.random.randint(num_classes, size=(512, 512))
    save_mask(x1, path_test)
    x2 = _read_msk(path_test, num_classes=num_classes)
    test_eq(x1, x2)

Base Class


tiles_in_rectangles(H, W, h, w)

Get smaller rectangles needed to fill the larger rectangle
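One plausible reading of this description is a ceiling division per axis: the number of tiles along the height times the number along the width. The sketch below follows that reading. The name count_tiles and the 540 x 540 example size are assumptions, not taken from the library.

```python
import math

def count_tiles(H, W, h, w):
    """Hypothetical helper: number of h x w tiles needed to cover an H x W area."""
    return math.ceil(H / h) * math.ceil(W / w)

n_large = count_tiles(540, 540, 512, 512)  # 2 tiles per axis
n_small = count_tiles(540, 540, 224, 224)  # 3 tiles per axis
```

Under this reading, tiles along each axis may overlap or extend past the border, since the count is rounded up rather than down.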

class BaseDataset

BaseDataset(files, label_fn=None, instance_labels=False, num_classes=2, ignore={}, remove_connectivity=True, stats=None, normalize=True, use_zarr_data=True, tile_shape=(512, 512), padding=(0, 0), preproc_dir=None, verbose=1, scale=1, pdf_reshape=512, use_preprocessed_labels=False, **kwargs) :: Dataset

An abstract class representing a :class:Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:__getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:__len__, which is expected to return the size of the dataset by many :class:Sampler implementations and the default options of :class:DataLoader.

.. note:: :class:DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

files = get_image_files(path/'images')
label_fn = lambda o: path/'labels'/f'{o.stem}_mask.png'
tst = BaseDataset(files, label_fn=label_fn, num_classes=2)
Preprocessing data
100.00% [1/1 00:00<00:00]
Calculated stats {'channel_means': array([100.18701303]), 'channel_stds': array([84.32689916]), 'max_tiles_per_image': 4}
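The calculated stats hold per-channel means and standard deviations. Assuming that normalize=True applies the usual standardization (subtract the mean, divide by the standard deviation), which is not stated explicitly here, the step can be sketched as follows. The function name standardize is hypothetical and the image is random placeholder data.

```python
import numpy as np

def standardize(img, channel_means, channel_stds):
    """Hypothetical helper: per-channel (x - mean) / std on an (H, W, C) array."""
    means = np.asarray(channel_means, dtype=np.float64)
    stds = np.asarray(channel_stds, dtype=np.float64)
    # Broadcasting applies each channel's mean/std along the last axis
    return (img.astype(np.float64) - means) / stds

rng = np.random.default_rng(0)
img_demo = rng.integers(0, 256, size=(16, 16, 1))
stats_demo = {
    "channel_means": img_demo.mean(axis=(0, 1)),
    "channel_stds": img_demo.std(axis=(0, 1)),
}
img_norm = standardize(img_demo, stats_demo["channel_means"], stats_demo["channel_stds"])
```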


For training

class RandomTileDataset

RandomTileDataset(*args, sample_mult=None, flip=True, rotation_range_deg=(0, 360), scale_range=(0, 0), albumentations_tfms=[RandomGamma(always_apply=False, p=0.5, gamma_limit=(80, 120), eps=None)], min_length=400, **kwargs) :: BaseDataset

PyTorch Dataset that creates random tiles with augmentations from the input images.

Args:

  • csv_file (string): Path to the csv file with annotations.
  • root_dir (string): Directory with all the images.
  • tile_shape: The tile shape the network expects as input.
  • padding: The padding (input shape - output shape).
  • classlabels: A list containing the corresponding class labels. 0 = ignore, 1 = background, 2-n foreground classes. If None, the problem will be treated as binary segmentation.
  • num_classes: The number of classes including background.
  • ignore: A list containing the corresponding ignore regions.
  • weights: A list containing the corresponding weights.
  • element_size_um: The target pixel size in micrometers.
  • batch_size: The number of tiles to generate per batch.
  • rotation_range_deg: (alpha_min, alpha_max): The range of rotation angles. A random rotation is drawn from a uniform distribution in the given range.
  • flip: If true, a coin flip decides whether a mirrored tile will be generated.
  • deformation_grid: (dx, dy): The distance of neighboring grid points in pixels for which random deformation vectors are drawn.
  • deformation_magnitude: (sx, sy): The standard deviations of the Gaussians from which the components of the deformation vector are drawn.
  • value_minimum_range: (v_min, v_max): Input intensity zero will be mapped to a random value in the given range.
  • value_maximum_range: (v_min, v_max): Input intensity one will be mapped to a random value within the given range.
  • value_slope_range: (s_min, s_max): The slope at control points is drawn from a uniform distribution in the given range.

Show data

tst = RandomTileDataset(files, label_fn=label_fn, verbose=2, scale=1, num_classes=2)
Preprocessing data
100.00% [1/1 00:00<00:00]
Calculated stats {'channel_means': array([100.18701303]), 'channel_stds': array([84.32689916]), 'max_tiles_per_image': 4}

Show random tile

tile = tst[0]
show(tile[0], tile[1])
img_path = tst.files[0]
cdf = tst.pdfs[][:] 
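# Draw 1000 tile centers from the sampling distribution
centers = [tst._random_center(cdf, mask.shape) for _ in range(1000)]
# Centers are (row, column) pairs, so columns go on the x-axis
xs = [x[1] for x in centers]
ys = [x[0] for x in centers]
plt.scatter(x=xs, y=ys, c='r', s=2)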
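The internals of _random_center are not shown here. A common way to draw positions from a cumulative distribution is inverse transform sampling: pick a uniform random number, then find the first entry of the CDF that exceeds it. The sketch below illustrates that technique in pure Python. The names build_cdf and sample_center and the toy weight map are assumptions for illustration only.

```python
import bisect
import itertools
import random

def build_cdf(weights):
    """Cumulative sums over a 2D weight map, flattened row by row."""
    flat = [w for row in weights for w in row]
    return list(itertools.accumulate(flat))

def sample_center(cdf, shape, rng=random):
    """Draw one (row, col) position with probability proportional to its weight."""
    u = rng.random() * cdf[-1]
    # First flat index whose cumulative weight exceeds u
    idx = min(bisect.bisect_right(cdf, u), len(cdf) - 1)
    return divmod(idx, shape[1])

# Toy weight map: only two pixels can be chosen, with a 3:1 preference
weights_demo = [[0, 0, 0],
                [0, 3, 1],
                [0, 0, 0]]
cdf_demo = build_cdf(weights_demo)
rng_demo = random.Random(0)
centers_demo = [sample_center(cdf_demo, (3, 3), rng_demo) for _ in range(1000)]
```

Pixels with zero weight are never selected, so a weight map derived from the mask concentrates tile centers on annotated regions.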


class TileDataset

TileDataset(*args, val_length=None, val_seed=42, max_tile_shift=1.0, border_padding_factor=0.25, return_index=False, **kwargs) :: BaseDataset

PyTorch Dataset that creates random tiles for validation and prediction on new data.
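How TileDataset places its tile centers is not spelled out here. As a rough illustration of tiling an image for prediction, the sketch below spaces centers evenly so that the tiles cover the full image. It is a generic approach under stated assumptions, not the library's method: the function names are hypothetical, and the parameters val_length, max_tile_shift, and border_padding_factor are ignored.

```python
import math

def grid_centers_1d(length, tile):
    """Evenly spaced tile centers along one axis so tiles cover [0, length]."""
    n = math.ceil(length / tile)
    if n == 1:
        return [length / 2]
    first, last = tile / 2, length - tile / 2
    step = (last - first) / (n - 1)
    return [first + i * step for i in range(n)]

def grid_centers(shape, tile_shape):
    """All (row, col) center combinations for a 2D image."""
    rows = grid_centers_1d(shape[0], tile_shape[0])
    cols = grid_centers_1d(shape[1], tile_shape[1])
    return [(r, c) for r in rows for c in cols]

# Example: a 540 x 540 image covered by 224 x 224 tiles
centers_grid = grid_centers((540, 540), (224, 224))
```

With these example sizes, neighbouring tiles overlap slightly, since three tiles of 224 pixels span more than 540 pixels.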

Show data

tst = TileDataset(files, label_fn=label_fn, num_classes=2, tile_shape=(224,224), padding=(0,0), scale=1., val_length=6)
Preprocessing data
100.00% [1/1 00:00<00:00]
Calculated stats {'channel_means': array([100.18701303]), 'channel_stds': array([84.32689916]), 'max_tiles_per_image': 9}

Show tiles on image

  • Center points are indicated with red dots
fig, axs = plt.subplots(figsize=(10,10))
axs.imshow(ndimage.zoom(tst.get_data(max_n=1)[0][...,0], 1/tst.scale))
xs = [x[1]/tst.scale for x in tst.centers]
ys = [x[0]/tst.scale for x in tst.centers]
axs.scatter(x=xs, y=ys, c='r', s=20)
for xsi, ysi in zip(xs, ys):
    o = tst.output_shape
    rect = patches.Rectangle((xsi-o[0]/2,ysi-o[1]/2), o[0], o[1], linewidth=1, edgecolor='r', alpha=0.3)