PyTorch segmentation models.

Segmentation Models PyTorch Integration

From the website:

  • High-level API (just two lines to create a neural network)
  • 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
  • 104 available encoders
  • All encoders have pre-trained weights for faster and better convergence

See https://github.com/qubvel/segmentation_models.pytorch for API details.

get_pretrained_options

get_pretrained_options(encoder_name)

Return the available pretrained-weight options for a given encoder.

create_smp_model

create_smp_model(arch, **kwargs)

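Create a segmentation_models_pytorch model: arch names the architecture, and the keyword arguments (e.g. encoder_name, encoder_weights, in_channels, classes) configure it.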

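import torch
from fastcore.test import test_eq

# ARCHITECTURES, ENCODERS and create_smp_model come from this module.
bs = 2
tile_shapes = [512]       # alternatives: 1024
in_channels = [1]         # alternatives: 1, 3, 4
classes = [2]             # alternatives: 2, 5
encoders = ENCODERS[1:2]  # optionally extend with ENCODERS[-1:]

# Each architecture/encoder pair must map an input of shape (bs, in_c, ts, ts)
# to logits of shape (bs, c, ts, ts); encoder_weights=None avoids downloads.
for ts in tile_shapes:
    for in_c in in_channels:
        for c in classes:
            inp = torch.randn(bs, in_c, ts, ts)
            out_shape = [bs, c, ts, ts]
            for arch in ARCHITECTURES:
                for encoder_name in encoders:
                    model = create_smp_model(arch=arch,
                                             encoder_name=encoder_name,
                                             encoder_weights=None,
                                             in_channels=in_c,
                                             classes=c)
                    out = model(inp)
                    test_eq(out.shape, out_shape)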

save_smp_model

save_smp_model(model, arch, file, stats=None, pickle_protocol=2)

Save an smp model to file, optionally including stats.
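A save function with this signature could bundle the weights, the architecture name and the stats into a single checkpoint dict and write it with torch.save, whose pickle_protocol argument matches the parameter above. This is a sketch under assumptions: the function name save_model_with_stats and the keys "model", "arch" and "stats" are hypothetical, and the file layout save_smp_model actually uses is not shown in this document. A small nn.Conv2d stands in for an smp model so the example runs quickly.

```python
import os
import tempfile

import torch
from torch import nn


def save_model_with_stats(model, arch, file, stats=None, pickle_protocol=2):
    """Hypothetical sketch: store weights, architecture name and stats together."""
    checkpoint = {
        "model": model.state_dict(),  # parameters and buffers
        "arch": arch,                 # lets a loader rebuild the right class
        "stats": stats,               # whatever stats the caller passes; may be None
    }
    torch.save(checkpoint, file, pickle_protocol=pickle_protocol)


# A small nn.Conv2d stands in for an smp model; any nn.Module works.
model = nn.Conv2d(1, 2, kernel_size=3)
path = os.path.join(tempfile.mkdtemp(), "ckpt.pth")
save_model_with_stats(model, "Conv2d", path, stats=(1, 1))
```

Storing the architecture name next to the state dict means a loader only needs the file to know which class to instantiate before calling load_state_dict.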

# Create a Unet, then save it together with (dummy) stats.
arch = 'Unet'
file = 'tst.pth'
stats = (1,1)
kwargs = {'encoder_name': 'resnet34'}
tst = create_smp_model(arch, **kwargs)
save_smp_model(tst, arch, file, stats=stats)

load_smp_model

load_smp_model(file, device=None, strict=True, **kwargs)

Load an smp model from file and return it together with the saved stats.
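The device and strict parameters plausibly map onto PyTorch's own loading options: device onto torch.load's map_location, and strict onto nn.Module.load_state_dict. That mapping is an assumption based on the parameter names. The hypothetical helper load_weights below takes an existing model (unlike load_smp_model, which rebuilds the model from the file) and shows what strict=False tolerates.

```python
import os
import tempfile

import torch
from torch import nn


def load_weights(model, file, device=None, strict=True):
    """Hypothetical sketch: load a saved state dict into `model`.

    device: passed to torch.load as map_location (None keeps saved devices).
    strict: passed to load_state_dict; False tolerates missing/unexpected keys.
    """
    state = torch.load(file, map_location=device)
    result = model.load_state_dict(state, strict=strict)
    return model, result


path = os.path.join(tempfile.mkdtemp(), "weights.pth")
src = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU())
torch.save(src.state_dict(), path)

# Same layers plus an extra head: strict=True would raise a RuntimeError.
dst = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Conv2d(4, 2, 1))
dst, result = load_weights(dst, path, device="cpu", strict=False)
print(result.missing_keys)  # ['2.weight', '2.bias']
```

With strict=False the matching layers are loaded and the new head keeps its random initialization, which is useful when fine-tuning with a different number of classes; strict=True is the safer default for reloading an identical model.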

# Round trip: the reloaded weights and stats must match the originals.
tst2, stats2 = load_smp_model(file)
for p1, p2 in zip(tst.parameters(), tst2.parameters()):
    test_eq(p1.detach(), p2.detach())
test_eq(stats, stats2)