PyTorch segmentation models.
Segmentation Models PyTorch Integration
From the website:
- High level API (just two lines to create a neural network)
- 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 104 available encoders
- All encoders have pre-trained weights for faster and better convergence
See https://github.com/qubvel/segmentation_models.pytorch for API details.
# Shape smoke test: every (architecture, encoder) pair created by
# `create_smp_model` must map a (bs, in_c, ts, ts) input tensor to a
# (bs, c, ts, ts) output tensor.
# The sweep is intentionally small to keep the runtime short; it can be
# widened (e.g. tile size 1024, in_channels 3/4, more of ENCODERS).
bs = 2
tile_shapes = [256, 640]
in_channels = [1]
classes = [2, 5]
encoders = ENCODERS[0:1] + ['tu-convnext_tiny']
archs = ARCHITECTURES[0:1]
for ts in tile_shapes:
    for in_c in in_channels:
        # The input depends only on tile size and channel count, so build it
        # once per (ts, in_c) rather than once per class count.
        inp = torch.randn(bs, in_c, ts, ts)
        for c in classes:
            out_shape = [bs, c, ts, ts]
            for arch in archs:
                for encoder_name in encoders:
                    model = create_smp_model(arch=arch,
                                             encoder_name=encoder_name,
                                             in_channels=in_c,
                                             classes=c)
                    # Only the output shape is checked, so skip building the
                    # autograd graph (saves memory on the larger tiles).
                    with torch.no_grad():
                        out = model(inp)
                    test_eq(out.shape, out_shape)
                    # Release the model before building the next one.
                    del model
# Round-trip test: a model written with `save_smp_model` must load back via
# `load_smp_model` with identical parameters and the same `stats` tuple.
arch = 'Unet'
path = 'tst.pth'
stats = (1,1)
kwargs = {'encoder_name': 'resnet34'}
tst = create_smp_model(arch, **kwargs)
# save_smp_model returns the path it wrote to; it is used for cleanup below.
path = save_smp_model(tst, arch, path, stats=stats)
try:
    tst2, stats2 = load_smp_model(path)
    params, params2 = list(tst.parameters()), list(tst2.parameters())
    # zip() would silently stop at the shorter list, so compare counts first.
    test_eq(len(params), len(params2))
    for p1, p2 in zip(params, params2):
        test_eq(p1.detach(), p2.detach())
    test_eq(stats, stats2)
finally:
    # Remove the checkpoint even when one of the assertions above fails.
    path.unlink()
Cellpose integration
Cellpose is a tool for reliable cell and nucleus segmentation. Visit the Cellpose project for more information.
The Cellpose integration for deepflash2 is tested on version 0.6.6.dev13+g316927e.
# Smoke test for `run_cellpose`: a single random 512x512 probability map,
# thresholded at 0 to build its foreground mask.
probs = [np.random.rand(512, 512)]
masks = [np.greater(prob, 0.0) for prob in probs]
cp_preds = run_cellpose(probs, masks, diameter=17.0)
# The prediction must keep the spatial shape of the probability map it came from.
test_eq(probs[0].shape, cp_preds[0].shape)