Example image and mask
We will use an example image and mask to guide you through the documentation.
Plot example image and mask
image = imageio.imread(path/'images'/'01.png')
mask = imageio.imread(path/'labels'/'01_mask.png')
show(image, mask)
Supported segmentation mask types:
Class labels: pixel annotations of classes (e.g., 0 for background and 1...n for positive classes)
Instance labels: pixel annotations of the instance each pixel belongs to (e.g., 0 for background, 1 for the first ROI, 2 for the second ROI, etc.).
_show(mask, inst_labels)
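The two mask types relate as shown in the following minimal sketch, which uses made-up toy data. `scipy.ndimage.label` assigns one id per connected component to produce instance labels, and thresholding at zero collapses them back into a single positive class. This illustrates the label formats only and is not how the library builds its masks.

```python
import numpy as np
from scipy import ndimage

# Toy binary foreground mask with two separate objects (made-up data)
binary = np.zeros((6, 8), dtype=np.uint8)
binary[1:3, 1:3] = 1
binary[3:5, 5:7] = 1

# Instance labels: each connected component gets its own id (0 = background)
toy_inst, n_instances = ndimage.label(binary)

# Class labels: all instances collapse into a single positive class (1)
toy_class = (toy_inst > 0).astype(np.uint8)

print(n_instances)          # 2
print(np.unique(toy_inst))  # [0 1 2]
```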
The provided segmentation masks are preprocessed to
- convert instance labels to class labels
- draw small ridges between touching instances (optional)
Arguments of preprocess_mask:
- clabels: class labels (segmentation mask)
- instlabels: instance labels (segmentation mask)
- num_classes (int): number of classes
tst1 = preprocess_mask(mask, remove_connectivity=False)
tst2 = preprocess_mask(mask, remove_connectivity=True)
_show(tst1,tst2)
ind = (slice(200,230), slice(230,260))
print('Zoom in on borders:')
_show(tst1[ind], tst2[ind])
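The ridge-drawing step can be sketched as follows. The helper `draw_ridges` is hypothetical and is not the library's `preprocess_mask`. It sets to background every foreground pixel that has a 4-neighbour belonging to a different instance, so touching objects stay separated after conversion to class labels. This simple rule removes one pixel on each side of a contact, which leaves a 2-pixel gap.

```python
import numpy as np

def draw_ridges(inst_labels):
    """Hypothetical helper: zero every foreground pixel that has a
    4-neighbour belonging to a different foreground instance."""
    lab = np.asarray(inst_labels)
    h, w = lab.shape
    padded = np.pad(lab, 1, mode="constant")  # border padding counts as background
    ridge = np.zeros(lab.shape, dtype=bool)
    for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
        neigh = padded[1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
        ridge |= (lab > 0) & (neigh > 0) & (neigh != lab)
    out = lab.copy()
    out[ridge] = 0
    return out

# Two touching instances: columns 0-2 belong to id 1, columns 3-5 to id 2
inst = np.zeros((4, 6), dtype=int)
inst[:, :3] = 1
inst[:, 3:] = 2
separated = draw_ridges(inst)
class_mask = (separated > 0).astype(np.uint8)  # binary mask with a gap
print(class_mask[0])  # [1 1 0 0 1 1]
```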
The DeformationField class ensures that all augmentations are applied identically to images, masks, and weights. Implemented augmentations are:
- rotation
- mirroring
- random deformation
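Applying augmentations identically amounts to computing one sampling grid and reusing it for every array. The following is a minimal sketch of that idea using SciPy, not DeformationField's actual implementation. Random displacements are drawn on a coarse grid and upsampled. Intensities are then resampled bilinearly (`order=1`) and labels with nearest neighbour (`order=0`) so they stay integral. The function name, grid spacing, and magnitudes are assumptions.

```python
import numpy as np
from scipy import ndimage

def random_deformation_coords(shape, grid=(8, 8), magnitude=(2.0, 2.0), seed=0):
    """Hypothetical sketch: one sampling grid shared by image, mask and weights.
    Displacements are drawn on a coarse grid (spacing `grid` px, std `magnitude`)
    and upsampled to full resolution with cubic spline interpolation."""
    rng = np.random.default_rng(seed)
    coarse = (shape[0] // grid[0] + 1, shape[1] // grid[1] + 1)
    zoom = (shape[0] / coarse[0], shape[1] / coarse[1])
    coords = np.mgrid[0:shape[0], 0:shape[1]].astype(float)
    for axis in range(2):
        disp = rng.normal(0.0, magnitude[axis], size=coarse)
        coords[axis] += ndimage.zoom(disp, zoom, order=3)
    return coords

def warp(arr, coords, order):
    # order=1: bilinear for intensities; order=0: nearest neighbour for labels
    return ndimage.map_coordinates(arr, coords, order=order, mode="reflect")

rng = np.random.default_rng(1)
toy_image = rng.random((32, 32))
toy_mask = (toy_image > 0.5).astype(np.uint8)
coords = random_deformation_coords(toy_image.shape)
warped_image = warp(toy_image, coords, order=1)
warped_mask = warp(toy_mask, coords, order=0)
```

Because both arrays are sampled at the same coordinates, the warped mask still matches the warped image pixel for pixel.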
Original Image
tst = DeformationField(shape=(260, 260), scale=1, scale_range=(0.7, 1.4))
show(tst.apply(image, offset=(270,270)),
tst.apply(mask, offset=(270,270)))
Add mirroring
tst = DeformationField()
tst.mirror((1,1))
show(tst.apply(image, offset=(270,270)),
tst.apply(mask, offset=(270,270)))
Add rotation
tst.rotate(1)
show(tst.apply(image, offset=(270,270)),
tst.apply(mask, offset=(270,270)))
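Mirroring and rotation can both be expressed as a transform of tile-local coordinates around the tile centre, followed by a shift to the `offset` in the source image. The sketch below is hypothetical: the function name, the (row, col) convention, and the use of radians are assumptions, and it does not reproduce DeformationField's internals. With no mirroring and zero angle it simply reproduces a crop.

```python
import numpy as np
from scipy import ndimage

def mirror_rotate_coords(shape, offset, mirror=(False, False), angle=0.0):
    """Hypothetical sketch: sampling coordinates for a tile of size `shape`
    centred at `offset` (row, col), optionally mirrored per axis and rotated
    by `angle` radians about the tile centre."""
    h, w = shape
    yy, xx = np.mgrid[0:h, 0:w].astype(float)
    yy -= (h - 1) / 2  # tile-local coordinates relative to the centre
    xx -= (w - 1) / 2
    if mirror[0]:
        yy = -yy
    if mirror[1]:
        xx = -xx
    c, s = np.cos(angle), np.sin(angle)
    rows = c * yy - s * xx + offset[0]
    cols = s * yy + c * xx + offset[1]
    return np.stack([rows, cols])

src = np.arange(100, dtype=float).reshape(10, 10)
crop = src[3:7, 3:7]
plain = ndimage.map_coordinates(src, mirror_rotate_coords((4, 4), (4.5, 4.5)), order=1)
flipped = ndimage.map_coordinates(
    src, mirror_rotate_coords((4, 4), (4.5, 4.5), mirror=(False, True)), order=1)
rotated = ndimage.map_coordinates(
    src, mirror_rotate_coords((4, 4), (4.5, 4.5), angle=np.pi / 2), order=1)
```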
Datasets
PyTorch map-style datasets for training and validation.
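In PyTorch terms, a map-style dataset is one that implements `__getitem__` and `__len__`. The toy class below is hypothetical and unrelated to the classes documented here. It shows that protocol with plain NumPy arrays. In practice you would subclass `torch.utils.data.Dataset`, and the intensity scaling used here is an assumption.

```python
import numpy as np

class ToySegmentationDataset:
    """Hypothetical minimal map-style dataset: indexable and sized."""

    def __init__(self, images, masks):
        if len(images) != len(masks):
            raise ValueError("images and masks must have the same length")
        self.images, self.masks = images, masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Scale 8-bit intensities to [0, 1]; keep labels as integers
        img = self.images[idx].astype(np.float32) / 255.0
        msk = self.masks[idx].astype(np.int64)
        return img, msk

imgs = [np.full((4, 4), 255, dtype=np.uint8) for _ in range(3)]
msks = [np.ones((4, 4), dtype=np.uint8) for _ in range(3)]
ds = ToySegmentationDataset(imgs, msks)
img, msk = ds[0]
```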
path_test = path/'mask.png'
for num_classes in [2,3,10]:
    x1 = np.random.randint(num_classes, size=(512, 512))
    save_mask(x1, path_test)
    x2 = _read_msk(path_test, num_classes=num_classes)
    test_eq(x1, x2)
    path_test.unlink()
files = get_image_files(path/'images')
label_fn = lambda o: path/'labels'/f'{o.stem}_mask.png'
tst = BaseDataset(files, label_fn=label_fn, num_classes=2)
tst.show_data()
Args:
- csv_file (string): Path to the csv file with annotations.
- root_dir (string): Directory with all the images.
- tile_shape: The tile shape the network expects as input.
- padding: The padding (input shape - output shape).
- classlabels: A list containing the corresponding class labels. 0 = ignore, 1 = background, 2-n foreground classes. If None, the problem will be treated as binary segmentation.
- num_classes: The number of classes, including background.
- ignore: A list containing the corresponding ignore regions.
- weights: A list containing the corresponding weights.
- element_size_um: The target pixel size in micrometers.
- batch_size: The number of tiles to generate per batch.
- rotation_range_deg: (alpha_min, alpha_max): The range of rotation angles. A random rotation is drawn from a uniform distribution in the given range.
- flip: If true, a coin flip decides whether a mirrored tile will be generated.
- deformation_grid: (dx, dy): The distance in pixels between neighboring grid points at which random deformation vectors are drawn.
- deformation_magnitude: (sx, sy): The standard deviations of the Gaussians from which the components of the deformation vector are drawn.
- value_minimum_range: (v_min, v_max): Input intensity zero will be mapped to a random value in the given range.
- value_maximum_range: (v_min, v_max): Input intensity one will be mapped to a random value in the given range.
- value_slope_range: (s_min, s_max): The slope at control points is drawn from a uniform distribution in the given range.
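One way to read the three value_* arguments is as a random intensity curve. Intensity 0 maps to a random v0, intensity 1 maps to a random v1, and slopes drawn from value_slope_range are imposed at both control points. The sketch below interpolates with a cubic Hermite polynomial. The interpolation scheme, the default ranges, and the function name are assumptions, not the library's implementation. Note that the resulting curve is not guaranteed to be monotonic.

```python
import numpy as np

def random_intensity_curve(x, value_minimum_range=(0.0, 0.1),
                           value_maximum_range=(0.9, 1.0),
                           value_slope_range=(0.5, 2.0), rng=None):
    """Hypothetical sketch: cubic Hermite curve through (0, v0) and (1, v1)
    with random end slopes s0, s1; inputs are assumed normalised to [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    v0 = rng.uniform(*value_minimum_range)
    v1 = rng.uniform(*value_maximum_range)
    s0, s1 = rng.uniform(*value_slope_range, size=2)
    t = np.clip(x, 0.0, 1.0)
    # Cubic Hermite basis functions
    h00 = 2 * t**3 - 3 * t**2 + 1
    h10 = t**3 - 2 * t**2 + t
    h01 = -2 * t**3 + 3 * t**2
    h11 = t**3 - t**2
    return h00 * v0 + h10 * s0 + h01 * v1 + h11 * s1

y = random_intensity_curve(np.linspace(0.0, 1.0, 5), rng=np.random.default_rng(0))
```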
Show data
tst = RandomTileDataset(files, label_fn=label_fn, verbose=2, scale=1, num_classes=2)
tst.show_data()
Show random tile
tile = tst[0]
show(tile[0], tile[1])
img_path = tst.files[0]
cdf = tst.pdfs[img_path.name][:]
centers = [tst._random_center(cdf, mask.shape) for _ in range(int(1e+3))]
plt.imshow(mask)
xs = [x[1] for x in centers]
ys = [x[0] for x in centers]
plt.scatter(x=xs, y=ys, c='r', s=2)
plt.show()
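The scatter plot above shows tile centres drawn according to a per-image probability map. Inverse transform sampling on the flattened cumulative distribution is a standard way to do this. The helper below is a hypothetical sketch of that technique, not the body of `_random_center`, and it builds the CDF from a PDF itself.

```python
import numpy as np

def random_center(pdf, rng):
    """Hypothetical sketch: draw a (row, col) centre with probability
    proportional to `pdf`, via inverse transform sampling on the flattened CDF."""
    cdf = np.cumsum(pdf.ravel(), dtype=float)
    cdf /= cdf[-1]  # normalise so the last entry is exactly 1
    # side="right" picks the first cell whose CDF exceeds u, skipping zero-mass cells
    idx = np.searchsorted(cdf, rng.random(), side="right")
    return tuple(int(i) for i in np.unravel_index(idx, pdf.shape))

rng = np.random.default_rng(0)
pdf = np.zeros((5, 5))
pdf[1, 1] = 3.0  # three times as likely as (3, 4)
pdf[3, 4] = 1.0
centers = [random_center(pdf, rng) for _ in range(1000)]
```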
Show data
tst = TileDataset(files, label_fn=label_fn, num_classes=2, tile_shape=(224,224), padding=(0,0), scale=1., val_length=6)
tst.show_data()
Show tiles on image
- Center points are indicated with red dots
fig, axs = plt.subplots(figsize=(10,10))
axs.imshow(ndimage.zoom(tst.get_data(max_n=1)[0][...,0], 1/tst.scale))
xs = [x[1]/tst.scale for x in tst.centers]
ys = [x[0]/tst.scale for x in tst.centers]
axs.scatter(x=xs, y=ys, c='r', s=20)
for xsi, ysi in zip(xs, ys):
    o = tst.output_shape
    rect = patches.Rectangle((xsi-o[0]/2,ysi-o[1]/2), o[0], o[1], linewidth=1, edgecolor='r', alpha=0.3)
    axs.add_patch(rect)
plt.show()
shutil.rmtree(path)
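For comparison with the randomly sampled centres of RandomTileDataset, a regular grid is one simple way to cover an image with tiles. The sketch below is hypothetical and may differ from how TileDataset computes `tst.centers`. It places non-overlapping output tiles edge to edge, so tiles in the last row or column may extend past the image border.

```python
def grid_centers(image_shape, output_shape):
    """Hypothetical sketch: centres of non-overlapping output tiles covering the
    whole image; tiles in the last row/column may extend past the border."""
    centers = []
    for y in range(0, image_shape[0], output_shape[0]):
        for x in range(0, image_shape[1], output_shape[1]):
            centers.append((y + output_shape[0] / 2, x + output_shape[1] / 2))
    return centers

print(grid_centers((448, 448), (224, 224)))
# [(112.0, 112.0), (112.0, 336.0), (336.0, 112.0), (336.0, 336.0)]
```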