Implements popular segmentation loss functions.

Losses implemented here:

Loss Wrapper functions

Wrapper for handling different tensor types from fastai.

class FastaiLoss[source]

FastaiLoss(loss, axis=1) :: _Loss

Wrapper class around loss function for handling different tensor types.
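A minimal sketch of what such a wrapper can look like in plain PyTorch (illustrative only, not the library's actual implementation): it casts tensor subclasses such as fastai's TensorImage/TensorMask back to plain tensors before calling the wrapped loss, so losses that dispatch on tensor type do not break.

```python
import torch
import torch.nn as nn

class TensorTypeLossWrapper(nn.Module):
    """Hypothetical sketch: strip tensor subclasses (e.g. fastai's
    TensorImage/TensorMask) so any plain torch loss accepts the inputs."""
    def __init__(self, loss, axis=1):
        super().__init__()
        self.loss, self.axis = loss, axis

    def forward(self, output, target):
        # as_subclass(torch.Tensor) drops the subclass without copying data
        return self.loss(output.as_subclass(torch.Tensor),
                         target.as_subclass(torch.Tensor))
```

On plain tensors the wrapper is a no-op, so it can be applied unconditionally.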

Wrapper for combining different losses, adapted from pytorch-toolbelt

class WeightedLoss[source]

WeightedLoss(loss, weight=1.0) :: _Loss

Wrapper class around a loss function that scales the loss by a fixed factor. This class helps balance multiple losses that operate on different scales

class JointLoss[source]

JointLoss(first:Module, second:Module, first_weight=1.0, second_weight=1.0) :: _Loss

Wrap two loss functions into one. This class computes a weighted sum of two losses.
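Both wrappers are simple enough to sketch in plain PyTorch. The following is a hypothetical re-implementation (names are illustrative, not the library's actual code): the weighted wrapper scales a single loss by a constant, and the joint wrapper sums two such weighted losses evaluated on the same inputs.

```python
import torch.nn as nn

class WeightedLossSketch(nn.Module):
    """Illustrative: multiply a wrapped loss by a fixed weight."""
    def __init__(self, loss, weight=1.0):
        super().__init__()
        self.loss, self.weight = loss, weight

    def forward(self, *inputs):
        return self.loss(*inputs) * self.weight

class JointLossSketch(nn.Module):
    """Illustrative: weighted sum of two losses on the same inputs."""
    def __init__(self, first, second, first_weight=1.0, second_weight=1.0):
        super().__init__()
        self.first = WeightedLossSketch(first, first_weight)
        self.second = WeightedLossSketch(second, second_weight)

    def forward(self, *inputs):
        return self.first(*inputs) + self.second(*inputs)
```

A typical use would be `JointLossSketch(nn.CrossEntropyLoss(), DiceLoss(), 1.0, 1.0)` to combine a pixel-wise and a region-based loss.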

The get_loss() function loads popular segmentation losses from Segmentation Models Pytorch and kornia:

  • (Soft) CrossEntropy Loss
  • Dice Loss
  • Jaccard Loss
  • Focal Loss
  • Lovasz Loss
  • Tversky Loss


get_loss(loss_name, mode='multiclass', classes=[1], smooth_factor=0.0, alpha=0.5, beta=0.5, gamma=2.0, reduction='mean', **kwargs)

Load losses based on loss_name

n_classes = 2
# With fastai tensor types these would be wrapped as TensorImage/TensorMask:
# output = TensorImage(torch.randn(4, n_classes, 356, 356, requires_grad=True))
# target = TensorMask(torch.randint(0, n_classes, (4, 356, 356)))
output = torch.randn(4, n_classes, 356, 356, requires_grad=True)
target = torch.randint(0, n_classes, (4, 356, 356))
for loss_name in LOSSES:
    print(f'Testing {loss_name}')
    tst = get_loss(loss_name, classes=list(range(1, n_classes)))
    loss = tst(output, target)
Testing CrossEntropyLoss
Testing DiceLoss
Testing SoftCrossEntropyLoss
Testing CrossEntropyDiceLoss
Testing JaccardLoss
Testing FocalLoss
Testing LovaszLoss
Testing TverskyLoss
ce1 = get_loss('SoftCrossEntropyLoss', smooth_factor=0)
ce2 = CrossEntropyLossFlat(axis=1)
test_close(ce1(output, target), ce2(output, target), eps=1e-04)
jc = get_loss('JaccardLoss')
dc = get_loss('DiceLoss')
dc_loss = dc(output, target)
dc_to_jc = 2*dc_loss/(dc_loss+1) # convert Dice loss to Jaccard loss: J = D/(2-D), so 1-J = 2*(1-D)/(1+(1-D))
test_close(jc(output, target), dc_to_jc, eps=1e-02)
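The conversion above follows from the identity between the Dice and Jaccard indices: for similarity scores D and J computed from the same masks, J = D/(2-D), which for the losses 1-D and 1-J gives 1-J = 2*(1-D)/(1+(1-D)). A quick numeric check of this identity on two hypothetical binary masks:

```python
# Two hypothetical binary masks, represented as sets of "on" pixel indices
a = {0, 1, 2, 3, 4}
b = {3, 4, 5, 6}
i = len(a & b)                       # intersection size
dice    = 2 * i / (len(a) + len(b))  # Dice similarity
jaccard = i / len(a | b)             # Jaccard similarity

dice_loss, jaccard_loss = 1 - dice, 1 - jaccard
# Jaccard loss recovered from Dice loss via the identity used above
assert abs(jaccard_loss - 2 * dice_loss / (dice_loss + 1)) < 1e-12
```

The test above uses a looser eps because the library losses are computed from soft predictions, where the identity holds only approximately.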
tw = get_loss("TverskyLoss", alpha=0.5, beta=0.5)
test_close(dc(output, target), tw(output, target), eps=1e-02)
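The equality checked above follows from the Tversky index TI = TP/(TP + alpha*FP + beta*FN): with alpha = beta = 0.5 the denominator becomes TP + (FP+FN)/2, i.e. half of 2*TP + FP + FN, which is exactly the Dice index. A quick numeric check with hypothetical confusion counts:

```python
tp, fp, fn = 30, 10, 6     # hypothetical true/false positive and false negative counts
alpha = beta = 0.5
tversky = tp / (tp + alpha * fp + beta * fn)
dice = 2 * tp / (2 * tp + fp + fn)
assert abs(tversky - dice) < 1e-12
```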
output = torch.randn(4, n_classes, 356, 356)
output[:,1,...] = 0.5
tst = get_loss(loss_name='DiceLoss', classes=None) 
tst2 = get_loss(loss_name='DiceLoss', classes=list(range(1,n_classes))) 
test_ne(tst(output, target), tst2(output, target))