Implements popular segmentation loss functions.
Losses implemented here:
Wrapper for handling different tensor types from fastai.
Wrapper for combining different losses, adapted from pytorch-toolbelt
Popular segmentation losses
The get_loss()
function loads popular segmentation losses from Segmentation Models PyTorch and kornia:
- (Soft) CrossEntropy Loss
- Dice Loss
- Jaccard Loss
- Focal Loss
- Lovasz Loss
- TverskyLoss
# Smoke-test Poly1CrossEntropyLoss on random logits and integer masks.
n_classes = 2
preds = torch.randn(4, n_classes, 356, 356, requires_grad=True)
mask = torch.randint(0, n_classes, (4, 356, 356))
poly_ce = Poly1CrossEntropyLoss(num_classes=n_classes)
poly_loss = poly_ce(preds, mask)
n_classes = 2
#output = TensorImage(torch.randn(4, n_classes, 356, 356, requires_grad=True))
#target = TensorMask(torch.randint(0, n_classes, (4, 356, 356)))
output = torch.randn(4, n_classes, 356, 356, requires_grad=True)
target = torch.randint(0, n_classes, (4, 356, 356))
# Smoke-test every registered loss, scoring foreground classes only
# (class 0 is excluded via classes=list(range(1, n_classes))).
foreground = list(range(1, n_classes))
for loss_name in LOSSES:
    print(f'Testing {loss_name}')
    loss = get_loss(loss_name, classes=foreground)(output, target)
# With smooth_factor=0, SoftCrossEntropyLoss must agree with fastai's
# flattened cross-entropy over the class axis.
smp_ce = get_loss('SoftCrossEntropyLoss', smooth_factor=0)
fastai_ce = CrossEntropyLossFlat(axis=1)
test_close(smp_ce(output, target), fastai_ce(output, target), eps=1e-04)
# Dice and Jaccard losses are analytically related: since the scores obey
# dice = 2*jaccard / (1 + jaccard), the losses (1 - score) obey
# jaccard_loss = 2*dice_loss / (1 + dice_loss).
jaccard = get_loss('JaccardLoss')
dice = get_loss('DiceLoss')
dice_loss = dice(output, target)
expected_jaccard_loss = 2 * dice_loss / (dice_loss + 1)
test_close(jaccard(output, target), expected_jaccard_loss, eps=1e-02)
# Tversky with alpha = beta = 0.5 reduces to Dice.
tversky = get_loss("TverskyLoss", alpha=0.5, beta=0.5)
test_close(dice(output, target), tversky(output, target), eps=1e-02)
# Restricting DiceLoss to foreground classes must change its value. Bias
# channel 1 to a constant so the per-class scores genuinely differ, then
# compare all-classes scoring against foreground-only scoring.
output = torch.randn(4, n_classes, 356, 356)
output[:, 1, ...] = 0.5
dice_all = get_loss(loss_name='DiceLoss', classes=None)
dice_fg = get_loss(loss_name='DiceLoss', classes=list(range(1, n_classes)))
test_ne(dice_all(output, target), dice_fg(output, target))