PyTorch segmentation models.

Segmentation Models PyTorch Integration

From the website:

  • High level API (just two lines to create a neural network)
  • 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
  • 104 available encoders
  • All encoders have pre-trained weights for faster and better convergence

See https://github.com/qubvel/segmentation_models.pytorch for API details.

get_pretrained_options

get_pretrained_options(encoder_name)

Return the available pretrained-weight options for a given encoder
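How get_pretrained_options finds its answer is not shown here. In segmentation_models_pytorch, each entry of the encoder registry smp.encoders.encoders carries a pretrained_settings dictionary keyed by weight name. A minimal sketch, assuming a registry of that shape (the encoder and weight names below are hypothetical placeholders, not real smp entries):

```python
# Hypothetical registry mirroring the layout of smp's encoder dictionary:
# encoder name -> {"pretrained_settings": {weights_name: settings_dict}}
ENCODER_REGISTRY = {
    "encoder_a": {"pretrained_settings": {"imagenet": {}, "other_weights": {}}},
    "encoder_b": {"pretrained_settings": {"imagenet": {}}},
}


def get_pretrained_options_sketch(encoder_name, registry=ENCODER_REGISTRY):
    """Return the names of the pretrained weights available for an encoder."""
    if encoder_name not in registry:
        raise KeyError(f"Unknown encoder {encoder_name!r}")
    return list(registry[encoder_name]["pretrained_settings"].keys())


options_a = get_pretrained_options_sketch("encoder_a")
options_b = get_pretrained_options_sketch("encoder_b")
```

The returned names are what one would pass as encoder_weights when building a model.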

UnetDecoder.forward

UnetDecoder.forward(*features)

create_smp_model

create_smp_model(arch, **kwargs)

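Create a segmentation_models_pytorch model of architecture arch; keyword arguments such as encoder_name, in_channels and classes are passed on to the model.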
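One plausible way to turn an architecture name such as 'Unet' into a model is a getattr lookup on the smp module, forwarding the keyword arguments to the class constructor. The sketch below uses hypothetical stand-in classes instead of the real smp ones so it runs without PyTorch; create_model_sketch, SUPPORTED_ARCHS and storing kwargs on the model are illustrative assumptions, not deepflash2's implementation:

```python
from types import SimpleNamespace


# Hypothetical stand-ins for smp model classes (the real ones are e.g. smp.Unet, smp.FPN).
class Unet:
    def __init__(self, encoder_name="resnet34", in_channels=3, classes=1):
        self.encoder_name = encoder_name
        self.in_channels = in_channels
        self.classes = classes


class FPN(Unet):
    pass


smp_stub = SimpleNamespace(Unet=Unet, FPN=FPN)  # plays the role of the smp module
SUPPORTED_ARCHS = ["Unet", "FPN"]  # hypothetical subset of architecture names


def create_model_sketch(arch, module=smp_stub, **kwargs):
    """Look up an architecture class by name and pass kwargs to its constructor."""
    if arch not in SUPPORTED_ARCHS:
        raise ValueError(f"Unknown architecture {arch!r}; choose from {SUPPORTED_ARCHS}")
    model = getattr(module, arch)(**kwargs)
    model.kwargs = kwargs  # remember how the model was built, e.g. for saving later
    return model


model = create_model_sketch("Unet", encoder_name="resnet18", in_channels=1, classes=2)
```

Keeping the construction arguments with the model makes it possible to rebuild the same architecture when weights are loaded from disk.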

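import torch
from fastcore.test import test_eq
# create_smp_model, ENCODERS and ARCHITECTURES are defined in this module

bs = 2
tile_shapes = [256, 640]  # optionally add 1024
in_channels = [1]  # e.g. 1, 3 or 4 channels
classes = [2, 5]
encoders = ENCODERS[0:1] + ['tu-convnext_tiny']  # first encoder plus a timm ('tu-') encoder
archs = ARCHITECTURES[0:1]

# Each model must map a (bs, in_c, ts, ts) input to a (bs, c, ts, ts) output
for ts in tile_shapes:
    for in_c in in_channels:
        for c in classes:
            inp = torch.randn(bs, in_c, ts, ts)
            out_shape = [bs, c, ts, ts]
            for arch in archs:
                for encoder_name in encoders:
                    model = create_smp_model(arch=arch,
                                             encoder_name=encoder_name,
                                             # encoder_weights=None,  # skip pretrained weights
                                             in_channels=in_c,
                                             classes=c)
                    out = model(inp)
                    test_eq(out.shape, out_shape)
del model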

save_smp_model

save_smp_model(model, arch, path, stats=None, pickle_protocol=2)

Save smp model, optionally including stats
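A saved model is only reusable if the file also records how to rebuild it. A minimal sketch of such a checkpoint, assuming a dictionary layout with hypothetical keys arch, kwargs, state_dict and stats, and using stdlib pickle in place of torch.save (which also accepts a pickle_protocol argument); save_checkpoint_sketch is a hypothetical helper, not deepflash2's code:

```python
import pickle
import tempfile
from pathlib import Path


def save_checkpoint_sketch(state_dict, arch, path, stats=None, pickle_protocol=2, **model_kwargs):
    """Bundle weights, architecture name, constructor kwargs and stats into one file."""
    path = Path(path)
    checkpoint = {
        "arch": arch,              # selects the model class when loading
        "kwargs": model_kwargs,    # constructor arguments, e.g. encoder_name
        "state_dict": state_dict,  # the weights themselves
        "stats": stats,            # optional extra data stored alongside the weights
    }
    with path.open("wb") as f:
        pickle.dump(checkpoint, f, protocol=pickle_protocol)
    return path


# Usage with a toy "state dict" of plain lists instead of tensors
tmp_dir = tempfile.mkdtemp()
saved = save_checkpoint_sketch({"conv.weight": [0.1, 0.2]}, "Unet", Path(tmp_dir) / "tst.pkl",
                               stats=(1, 1), encoder_name="resnet34")
raw = saved.read_bytes()
restored = pickle.loads(raw)
```

A low pickle protocol such as 2 trades some efficiency for compatibility with older Python versions.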

arch = 'Unet'
path = 'tst.pth'
stats = (1,1)
kwargs = {'encoder_name': 'resnet34'}
tst = create_smp_model(arch, **kwargs)
path = save_smp_model(tst, arch, path, stats=stats)

load_smp_model

load_smp_model(path, device=None, strict=True, **kwargs)

Load an smp model from file; returns the model together with the saved stats
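The strict flag presumably mirrors PyTorch's nn.Module.load_state_dict(strict=True), which raises a RuntimeError when the checkpoint has missing or unexpected keys; that load_smp_model forwards it this way is an assumption. The key comparison itself can be sketched in plain Python (check_state_dict_keys is a hypothetical helper):

```python
def check_state_dict_keys(expected_keys, loaded_state, strict=True):
    """Mimic the key check performed by load_state_dict(strict=...)."""
    expected, loaded = set(expected_keys), set(loaded_state)
    missing = sorted(expected - loaded)      # weights the model needs but the file lacks
    unexpected = sorted(loaded - expected)   # weights in the file the model does not know
    if strict and (missing or unexpected):
        raise RuntimeError(f"Missing keys: {missing}; unexpected keys: {unexpected}")
    return missing, unexpected


model_keys = ["encoder.weight", "decoder.weight"]
ok = check_state_dict_keys(model_keys, {"encoder.weight": 0, "decoder.weight": 0})
lenient = check_state_dict_keys(model_keys, {"encoder.weight": 0, "head.weight": 0}, strict=False)
try:
    check_state_dict_keys(model_keys, {"encoder.weight": 0})
    strict_raised = False
except RuntimeError:
    strict_raised = True
```

With strict=False, mismatches are reported instead of raised, which is useful when, for example, a segmentation head with a different number of classes is replaced.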

tst2, stats2 = load_smp_model(path)
for p1, p2 in zip(tst.parameters(), tst2.parameters()):
    test_eq(p1.detach(), p2.detach())
test_eq(stats, stats2)
path.unlink()

Cellpose integration

deepflash2 integrates Cellpose for reliable cell and nucleus segmentation. Visit the Cellpose repository (https://github.com/MouseLand/cellpose) for more information.

The Cellpose integration in deepflash2 is tested with Cellpose version 0.6.6.dev13+g316927e.

check_cellpose_installation

check_cellpose_installation(show_progress=True)
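check_cellpose_installation has no docstring here. Assuming its job includes detecting whether Cellpose is installed and whether the installed version matches the tested one, a minimal stdlib sketch could look like this (cellpose_version_status is a hypothetical helper; it does not install anything):

```python
from importlib import metadata

TESTED_CELLPOSE_VERSION = "0.6.6.dev13+g316927e"  # version stated in the text above


def cellpose_version_status(package="cellpose", tested=TESTED_CELLPOSE_VERSION):
    """Return ("missing" | "tested" | "untested", installed_version_or_None)."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return "missing", None
    return ("tested" if installed == tested else "untested"), installed


status, version = cellpose_version_status()
```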

get_diameters

get_diameters(masks)

Get object diameters from deepflash2 predictions
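Cellpose's own diameter estimate (utils.diameters, as far as we know) takes the median of the square roots of the object areas and divides it by sqrt(π)/2, i.e. it returns the diameter of a circle whose area is the median object area, d = 2·sqrt(A/π). A sketch of that heuristic, assuming the masks are already instance label images (0 = background) and that per-image values are averaged; both are assumptions, and median_diameter and get_diameters_sketch are hypothetical names:

```python
import numpy as np


def median_diameter(label_img):
    """Equivalent-circle diameter of the median object in a label image (0 = background)."""
    _, counts = np.unique(label_img[label_img > 0], return_counts=True)
    if counts.size == 0:
        return 0.0
    # A circle of area A has diameter 2 * sqrt(A / pi) = sqrt(A) / (sqrt(pi) / 2)
    return float(np.median(np.sqrt(counts)) / (np.sqrt(np.pi) / 2))


def get_diameters_sketch(label_imgs):
    """Average the per-image median diameters over images that contain objects."""
    diams = [median_diameter(m) for m in label_imgs]
    diams = [d for d in diams if d > 0]
    return float(np.mean(diams)) if diams else 0.0


img = np.zeros((32, 32), dtype=np.int32)
img[2:12, 2:12] = 1     # 10 x 10 object, area 100
img[20:30, 20:30] = 2   # second object with the same area
d = get_diameters_sketch([img, np.zeros((32, 32), dtype=np.int32)])
```

For two 100-pixel objects, d is 20/sqrt(π), about 11.3 pixels; the empty second image is ignored.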

run_cellpose

run_cellpose(probs, masks, model_type='nuclei', diameter=0, min_size=-1, gpu=True, flow_threshold=0.4)

Run cellpose on deepflash2 predictions
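Cellpose documents min_size as the minimum number of pixels per mask and lets -1 turn that filter off, so run_cellpose's default of min_size=-1 presumably keeps all masks (assuming the value is forwarded to Cellpose unchanged). The filtering idea can be sketched on a label image; filter_small_objects is a hypothetical helper that only approximates what Cellpose does internally:

```python
import numpy as np


def filter_small_objects(label_img, min_size=-1):
    """Set objects with fewer than min_size pixels to background; min_size=-1 disables this."""
    out = label_img.copy()
    if min_size == -1:
        return out
    labels, counts = np.unique(out[out > 0], return_counts=True)
    for lab, cnt in zip(labels, counts):
        if cnt < min_size:
            out[out == lab] = 0  # remaining labels are not renumbered
    return out


lbl = np.zeros((8, 8), dtype=np.int32)
lbl[0:4, 0:4] = 1  # 16 pixels
lbl[6, 6] = 2      # 1 pixel
kept = filter_small_objects(lbl, min_size=15)
untouched = filter_small_objects(lbl, min_size=-1)
```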

probs = [np.random.rand(512,512)]
masks = [x>0. for x in probs]
cp_preds = run_cellpose(probs, masks, diameter=17.)
test_eq(probs[0].shape, cp_preds[0].shape)
Using diameter of 17.0
2022-07-12 14:52:03,044 [INFO] WRITING LOG OUTPUT TO /home/magr/.cellpose/run.log
2022-07-12 14:52:05,343 [INFO] ** TORCH CUDA version installed and working. **
2022-07-12 14:52:05,343 [INFO] >>>> using GPU
2022-07-12 14:52:05,392 [INFO] ~~~ FINDING MASKS ~~~
2022-07-12 14:52:07,577 [INFO] >>>> TOTAL TIME 2.19 sec