import argparse

import torch
import torch.nn as nn

from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners


def barf():
    # Drop into the debugger so the failing tensors can be inspected in place.
    import pdb
    pdb.set_trace()


def assertEqual(tensor, expected, threshold=0.001):
    # Recursively compare (possibly nested) sequences of tensors, breaking
    # into pdb on the first element-wise difference above the threshold.
    if isinstance(tensor, (list, tuple)):
        for t, e in zip(tensor, expected):
            assertEqual(t, e, threshold)
    else:
        if (tensor - expected).abs().max() > threshold:
            barf()


def filter_requires_grad(tensors):
    return [t for t in tensors if t.requires_grad]


def test_rnns(experim_creator, control_creator, check_grad=True, verbose=False,
              seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
              miniBatch=64, device='cuda', seed=17):
    creator_args = dict(seqLength=seqLength, numLayers=numLayers,
                        inputSize=inputSize, hiddenSize=hiddenSize,
                        miniBatch=miniBatch, device=device, seed=seed)

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition: both runners must start from identical inputs and weights,
    # otherwise the output/grad comparisons below are meaningless.
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

    if check_grad:
        # The backward comparison is the expensive half of the test; gate it
        # on the check_grad flag.
        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(control_outputs, seed)
        experim_backward_inputs = experim.backward_setup(experim_outputs, seed)

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

    if verbose:
        print(experim.forward.graph_for(*experim.inputs))
    print('')
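# For reference, a minimal sketch of driving test_rnns by hand, mirroring what
# the __main__ block at the bottom of this file does ('jit' is one of the
# default runner names and is assumed to be registered with get_nn_runners):
#
#   name, creator, context = get_nn_runners('jit')[0]
#   with context():
#       test_rnns(creator, pytorch_lstm_creator, check_grad=True, verbose=False)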
def test_vl_py(**test_args):
    # XXX: This compares vl_py with vl_lstm.
    # It's done this way because those two don't give the same outputs, so
    # the result isn't an apples-to-apples comparison right now.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners('vl_py')[0]
    with context():
        print('testing {}...'.format(name))
        creator_keys = [
            'seqLength', 'numLayers', 'inputSize', 'hiddenSize',
            'miniBatch', 'device', 'seed'
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition: the experiment only takes the first two of the
        # control's inputs; both must still agree on those and on the params.
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

        # The variable-length runner returns a list of per-sequence outputs;
        # pad it back into a batch so it lines up with the control's output.
        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(
            (control_out, control_hiddens), test_args['seed'])
        experim_backward_inputs = experim.backward_setup(
            (experim_out, experim_hiddens), test_args['seed'])

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

        if test_args['verbose']:
            print(experim.forward.graph_for(*experim.inputs))
        print('')


def str_to_bool(value):
    # argparse's type=bool is a well-known pitfall: bool('False') is True,
    # since any non-empty string is truthy. Parse the flag text explicitly.
    return str(value).lower() in ('true', '1', 'yes')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test lstm correctness')

    parser.add_argument('--seqLength', default=100, type=int)
    parser.add_argument('--numLayers', default=1, type=int)
    parser.add_argument('--inputSize', default=512, type=int)
    parser.add_argument('--hiddenSize', default=512, type=int)
    parser.add_argument('--miniBatch', default=64, type=int)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--check_grad', default='True', type=str_to_bool)
    parser.add_argument('--variable_lstms', action='store_true')
    parser.add_argument('--seed', default=17, type=int)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--rnns', nargs='*',
                        help='What to run. jit_premul, jit, etc')

    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ['jit_premul', 'jit']
    print(args)

    if 'cuda' in args.device:
        assert torch.cuda.is_available()

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
    test_args = vars(args)
    # test_rnns/test_vl_py don't take these two, so drop them before
    # forwarding the remaining args as keyword arguments.
    del test_args['rnns']
    del test_args['variable_lstms']

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print('testing {}...'.format(name))
            test_rnns(creator, pytorch_lstm_creator, **test_args)
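# Example invocations, assuming this file lives in a package named `fastrnns`
# (required by the relative imports above) and, for the default --device cuda,
# that a CUDA device is available:
#
#   python -m fastrnns.test --rnns jit_premul jit
#   python -m fastrnns.test --variable_lstms --check_grad False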