"""Three small, independent PyTorch snippets kept in one script.

1. ``recurrent``: a scripted loop that applies ``scale * y / shift`` 100 times.
2. ``recurrent_scaleshift``: a scripted loop that applies ``scale * y + shift``
   64 times.
3. ``mean().backward()`` on an empty tensor, first on CPU and then on CUDA.

NOTE(review): these look like minimal reproductions for JIT/autograd
behaviour on CUDA -- confirm against wherever the snippets came from.
"""
import torch

# The driver sections below allocate tensors on the GPU.  They are skipped on
# hosts without CUDA so that the CPU-only parts of the script still run.
_HAS_CUDA = torch.cuda.is_available()


# --------------------------------------------------------------------------
# Snippet 1: repeated multiply/divide through a nested scripted function.
# --------------------------------------------------------------------------
@torch.jit.script
def fn(x, scale, shift):
    """Return ``scale * x / shift``."""
    return scale * x / shift


@torch.jit.script
def recurrent(x, scale, shift):
    """Apply ``fn`` to ``x`` 100 times in sequence and return the result."""
    y = x
    for i in range(100):
        y = fn(y, scale, shift)
    return y


if _HAS_CUDA:
    x = torch.randn(2, 2, device='cuda')
    scale = torch.randn(2, 2, device='cuda', requires_grad=True)
    shift = torch.randn(2, 2, device='cuda', requires_grad=True)
    inputs = [x, scale, shift]

    out = recurrent(x, scale, shift)
    # The returned graph is discarded here, exactly as in the original
    # snippet; inspect or print it when running interactively.
    recurrent.graph_for(x, scale, shift)
else:
    print("CUDA not available; skipping the `recurrent` driver.")


# --------------------------------------------------------------------------
# Snippet 2: repeated scale-and-shift written inline in the loop body.
# --------------------------------------------------------------------------
@torch.jit.script
def recurrent_scaleshift(x, scale, shift):
    """Apply ``y = scale * y + shift`` 64 times, starting from ``y = x``."""
    y = x
    for i in range(64):
        y = scale * y + shift
    return y


if _HAS_CUDA:
    x = torch.randn(2, 2, device='cuda')
    scale = torch.randn(2, 2, device='cuda', requires_grad=True)
    shift = torch.randn(2, 2, device='cuda', requires_grad=True)
    inputs = [x, scale, shift]

    out = recurrent_scaleshift(x, scale, shift)
    # As above, the graph is computed but not stored.
    recurrent_scaleshift.graph_for(x, scale, shift)
else:
    print("CUDA not available; skipping the `recurrent_scaleshift` driver.")


# --------------------------------------------------------------------------
# Snippet 3: backward through ``mean()`` of an empty tensor.
# --------------------------------------------------------------------------
x = torch.tensor([])
x.requires_grad = True
x.mean().backward()  # no error triggered

if _HAS_CUDA:
    # Same operation on the GPU copy of the empty tensor.
    # NOTE(review): presumably this is the call whose behaviour was under
    # investigation -- confirm.
    x = x.cuda()
    x.mean().backward()
else:
    print("CUDA not available; skipping the CUDA empty-tensor check.")