# Code adapted from https://github.com/qubvel/ttach.
# The bare base transform must compile under TorchScript on its own;
# torch.jit.script raises if it cannot. The arguments are presumably a
# parameter name and its candidate values -- confirm against BaseTransform.
tst_base = BaseTransform('x', (1, 2))
torch.jit.script(tst_base)
# Random batch shared by every round-trip check below. The layout is
# presumably (batch, channels, height, width) -- TODO confirm against the
# transforms' expected input.
imgs = torch.randn(4, 2, 356, 356)

# Flips: augment (trailing flag False), then de-augment (flag True), and
# expect the exact input back. This uses the same flag order as the Rotate90
# and Chain checks in this file. A flip is its own inverse, so swapping the
# flags would still pass; the order is kept consistent so `aug`/`deaug`
# mean what they say.
for flip in (HorizontalFlip(), VerticalFlip()):
    t = torch.jit.script(flip)
    aug = t(imgs, 1, False)
    # Guard against a transform that silently returns its input: a round
    # trip through the identity would pass trivially.
    assert not torch.equal(imgs, aug)
    deaug = t(aug, 1, True)
    test_eq(imgs, deaug)
# Rotate90: augment with angle 90 (flag False), then de-augment with the same
# angle argument (flag True), and expect the exact input back. The transform
# is configured with the same angle it is exercised with, so construction and
# usage agree.
t = torch.jit.script(Rotate90([90]))
aug = t(imgs, 90, False)
# A no-op rotation would make the round trip pass trivially.
assert not torch.equal(imgs, aug)
deaug = t(aug, 90, True)
test_eq(imgs, deaug)
# Chain presumably applies its transforms in order. Undoing it therefore uses
# a chain of the same transforms in reverse order, fed the arguments in
# reverse order, with the de-augment flag set. The result must match the
# input exactly.
tfms = [
    HorizontalFlip(),
    VerticalFlip(),
    Rotate90(angles=[90, 180, 270]),
]
args = [1, 1, 90]  # one argument per entry of tfms, same order
tst_chain = torch.jit.script(Chain(tfms))
tst_chain_deaug = torch.jit.script(Chain(list(reversed(tfms))))
aug = tst_chain(imgs, args, False)
deaug = tst_chain_deaug(aug, list(reversed(args)), True)
test_eq(imgs, deaug)
# Transformer binds the transforms to fixed args and exposes
# augment/deaugment. The scripted module is kept and exercised as well as the
# eager one. Discarding the torch.jit.script result only proves the module
# compiles; it never runs the compiled augment/deaugment. The Compose check in
# this file calls augment/deaugment on scripted Transformer items, so they are
# expected to be callable after scripting.
tst_tfm = Transformer(tfms, args)
tst_tfm_scripted = torch.jit.script(tst_tfm)
for tfm in (tst_tfm, tst_tfm_scripted):
    aug = tfm.augment(imgs)
    deaug = tfm.deaugment(aug)
    test_eq(imgs, deaug)
# Compose is scripted as a whole. Each entry of its `items` exposes
# augment/deaugment; presumably there is one entry per parameter combination
# (confirm against Compose). Every entry must round-trip exactly, so the
# element-wise mean of all de-augmented batches must also match the input.
c = torch.jit.script(Compose(tfms))
out = []
for t in c.items:
    aug = t.augment(imgs)
    deaug = t.deaugment(aug)
    test_eq(imgs, deaug)
    out.append(deaug)
out = torch.stack(out)
test_close(imgs, out.mean(dim=0))