#!/usr/bin/env python3 # Owner(s): ["oncall: mobile"] import os import ctypes import torch from typing import Tuple from torch.backends._nnapi.prepare import convert_model_to_nnapi from torch.testing._internal.common_utils import TestCase, run_tests def qpt(t, scale, zero_point, dtype=torch.quint8): t = torch.tensor(t) return torch.quantize_per_tensor(t, scale, zero_point, dtype) def nhwc(t): t = t.clone().contiguous(memory_format=torch.channels_last) t.nnapi_nhwc = True return t class TestNNAPI(TestCase): def setUp(self): # Avoid saturation in fbgemm torch.backends.quantized.engine = 'qnnpack' libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH") if libneuralnetworks_path: ctypes.cdll.LoadLibrary(libneuralnetworks_path) print("Will attempt to run NNAPI models.") self.can_run_nnapi = True else: self.can_run_nnapi = False # Created for easy override by subclasses (eg TestNnapiBackend) def call_lowering_to_nnapi(self, traced_module, args): return convert_model_to_nnapi(traced_module, args) # Created for subclasses to set can_run_nnapi (eg TestNnapiBackend) def set_can_run_nnapi(self, can_run): self.can_run_nnapi = can_run def check( self, module, arg_or_args, *, trace_args=None, convert_args=None, atol_rtol=None, limit=None, expected_memory_format=None ): with torch.no_grad(): if isinstance(arg_or_args, torch.Tensor): args = [arg_or_args] else: args = arg_or_args module.eval() traced = torch.jit.trace(module, trace_args or args) nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args) if not self.can_run_nnapi: # Only test that the model was converted successfully. return eager_output = module(*args) nnapi_output = nnapi_module(*args) kwargs = {} if atol_rtol is not None: kwargs["atol"] = atol_rtol[0] kwargs["rtol"] = atol_rtol[1] self.assertEqual(eager_output, nnapi_output, **kwargs) if limit is not None: mismatches = \ eager_output.int_repr().to(torch.int32) - \ nnapi_output.int_repr().to(torch.int32) if mismatches.count_nonzero() > limit: # Too many mismatches. Re-run the check with no tolerance # to get a nice message. self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0) if expected_memory_format: self.assertTrue(nnapi_output.is_contiguous(memory_format=expected_memory_format)) def float_and_quant_and_nhwc(self, inp_float, scale, zero_point): torch.manual_seed(29) inp_quant = qpt(inp_float, 0.03, 128) return [ ("float", inp_float), ("float-nhwc", nhwc(inp_float)), ("quant", inp_quant), ("quant-nhwc", nhwc(inp_quant)), ] def test_prelu(self): arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1) single_a = torch.nn.PReLU() self.check(single_a, arg) multi_a = torch.nn.PReLU(4) with torch.no_grad(): multi_a.weight.copy_(torch.tensor([.1, .2, .3, .4])) self.check(multi_a, nhwc(arg)) # Test flexible size self.check( multi_a, arg, trace_args=[torch.zeros(1, 4, 3, 3)], convert_args=[nhwc(torch.zeros(1, 4, 0, 0))], ) def test_quantize(self): self.check( torch.nn.quantized.Quantize(0.25, 2, torch.quint8), nhwc(torch.tensor([[[[1.0]], [[2.0]]]]))) def test_dequantize(self): self.check( torch.nn.quantized.DeQuantize(), nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2))) def test_unsqueeze(self): class UnsqueezeModule(torch.nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, arg): return arg.unsqueeze(self.dim) self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2)) self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2)) self.check(UnsqueezeModule(0), torch.randn(4, 2, 2)) self.check(UnsqueezeModule(1), torch.randn(4, 2, 2)) self.check(UnsqueezeModule(2), torch.randn(4, 2, 2)) def test_reshape(self): class ReshapeModule(torch.nn.Module): def __init__(self, shape): super().__init__() self.shape = shape def forward(self, arg): return arg.reshape(self.shape) self.check( ReshapeModule((2, 4)), torch.randn(4, 2, 1, 1)) self.check( ReshapeModule((8, -1)), nhwc(torch.randn(4, 2, 1, 1))) with self.assertRaisesRegex(Exception, "target size"): self.check( ReshapeModule((2, 4)), nhwc(torch.randn(4, 2, 1, 1))) def test_flatten(self): for mod in [ torch.nn.Flatten(), torch.nn.Flatten(start_dim=2, end_dim=3), torch.nn.Flatten(start_dim=2, end_dim=4), torch.nn.Flatten(start_dim=0, end_dim=-2), torch.nn.Flatten(start_dim=0, end_dim=4) ]: self.check(mod, torch.randn(4, 2, 1, 3, 7)) # flex inputs self.check( torch.nn.Flatten(), torch.randn(4, 2, 1, 3, 7), convert_args=[torch.zeros(0, 2, 1, 3, 7)] ) # channels last self.check( torch.nn.Flatten(), nhwc(torch.randn(2, 1, 4, 7)) ) self.check( torch.nn.Flatten(), nhwc(torch.randn(2, 3, 1, 1)) ) # Exceptions with self.assertRaisesRegex(Exception, "not supported on NHWC"): self.check( torch.nn.Flatten(), nhwc(torch.randn(1, 3, 4, 4)) ) with self.assertRaisesRegex(Exception, "Flattening flexible dims is not supported yet"): self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7)) with self.assertRaisesRegex(Exception, "Only 1 dim"): self.check( torch.nn.Flatten(start_dim=1, end_dim=-2), torch.randn(0, 2, 1, 3, 0)) def test_slice(self): class SliceModule(torch.nn.Module): def __init__(self, start, stop, step): super().__init__() self.start = start self.stop = stop self.step = step def forward(self, t): return t[1:, self.start:self.stop:self.step, :] class SliceModule2(torch.nn.Module): def forward(self, t): return t[3:] self.check( SliceModule(1, 5, 2), torch.randn(4, 6, 2) ) self.check( SliceModule2(), torch.randn(5) ) # flex inputs self.check( SliceModule(1, 5, 2), torch.randn(4, 6, 2), convert_args=[torch.zeros(4, 6, 0)] ) with self.assertRaisesRegex(Exception, "slice with flexible shape"): self.check( SliceModule(1, 5, 2), torch.randn(4, 6, 2), convert_args=[torch.zeros(0, 0, 0)] ) def test_cat(self): class CatModule(torch.nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, t1, t2): return torch.cat([t1, t2], self.dim) self.check( CatModule(0), [ torch.randn(1, 2, 3, 3), torch.randn(2, 2, 3, 3), ]) self.check( CatModule(1), [ torch.randn(1, 2, 3, 3), torch.randn(1, 4, 3, 3), ]) self.check( CatModule(1), [ nhwc(torch.randn(1, 2, 3, 3)), nhwc(torch.randn(1, 4, 3, 3)), ]) self.check( CatModule(1), [ torch.randn(1, 2, 3, 3), torch.randn(1, 4, 3, 3), ], convert_args=[ torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0) ]) def test_pointwise_unary(self): for op in ["relu", "sigmoid"]: with self.subTest(op): class UnaryModule(torch.nn.Module): def forward(self, arg): if op == "relu": return torch.nn.functional.relu(arg) if op == "sigmoid": return torch.sigmoid(arg) raise Exception("Bad op") self.check(UnaryModule(), torch.tensor([-1.0, 1.0])) self.check( UnaryModule(), qpt(torch.tensor([-1.0, 1.0]), 1. / 256, 0), ) def test_pointwise_binary(self): for op in ["add", "sub", "mul", "div"]: with self.subTest(op): class BinaryModule(torch.nn.Module): def forward(self, lhs, rhs): if op == "add": return lhs + rhs if op == "sub": return lhs - rhs if op == "mul": return lhs * rhs if op == "div": return lhs / rhs raise Exception("Bad op") self.check( BinaryModule(), [ torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), ]) self.check( BinaryModule(), [ torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]]), ]) with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"): self.check( BinaryModule(), [ torch.tensor([1.0, 2.0]), torch.tensor([[3.0, 4.0], [5.0, 6.0]]), ]) def test_pointwise_binary_const(self): const = torch.randn(1, 4, 6, 6) class ArgPlusConst(torch.nn.Module): def forward(self, arg): return arg + const class ConstPlusArg(torch.nn.Module): def forward(self, arg): return const + arg arg_contig = torch.randn(2, 4, 6, 6) arg_nhwc = nhwc(torch.randn(2, 4, 6, 6)) for mod_class in [ArgPlusConst, ConstPlusArg]: for use_nhwc in [False, True]: with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc): arg = arg_nhwc if use_nhwc else arg_contig memory_format = torch.channels_last if use_nhwc else torch.contiguous_format self.check(mod_class(), arg, expected_memory_format=memory_format) def test_hardtanh(self): inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0]) self.check(torch.nn.Hardtanh(), inp) self.check(torch.nn.Hardtanh(0.0, 6.0), inp) with self.assertRaisesRegex(Exception, "hardtanh with args"): self.check(torch.nn.Hardtanh(0.0, 5.0), inp) def test_softmax(self): inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]]) self.check(torch.nn.Softmax(), inp) self.check(torch.nn.Softmax(dim=0), inp) # Test flexible size self.check( torch.nn.Softmax(), inp, convert_args=[torch.zeros(0, 0)], ) def test_to(self): class ToCPU(torch.nn.Module): def __init__(self): super().__init__() self.prelu = torch.nn.PReLU() def forward(self, x): y = x.to("cpu") # add prelu since input operand can't be output return self.prelu(y) arg = torch.randn(1, 2, 3, 3) self.check(ToCPU(), arg) # Test flexible size self.check( ToCPU(), arg, convert_args=[torch.zeros(1, 2, 0, 0)], ) def test_detach(self): class DetachModule(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): y = x.detach() return torch.nn.functional.relu(y) self.check(DetachModule(), torch.randn(1, 2, 3, 3)) self.check( DetachModule(), torch.randn(1, 2, 3, 3), convert_args=[torch.zeros(1, 2, 0, 0)]) def test_log_softmax(self): inp = torch.randn(3, 10) self.check(torch.nn.LogSoftmax(), inp) self.check(torch.nn.LogSoftmax(0), inp) def test_mean(self): class MeanModule(torch.nn.Module): def __init__(self, dim, keep=False): super().__init__() self.dim = dim self.keep = keep def forward(self, t): return torch.mean(t, dim=self.dim, keepdim=self.keep) self.check(MeanModule(0), torch.randn(2, 3)) self.check(MeanModule(1), torch.randn(2, 3)) self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6)) self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6))) self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6))) self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6))) def test_max_pool2d(self): for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128): with self.subTest(name): self.check(torch.nn.MaxPool2d(2), inp) self.check(torch.nn.MaxPool2d((3, 4)), inp) self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp) def test_avg_pool2d(self): for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128): with self.subTest(name): atol_rtol = None limit = None convert_dims = (2, 3, 0, 0) convert_arg = torch.zeros(*convert_dims) for model in ( torch.nn.AvgPool2d(2), torch.nn.AvgPool2d((3, 4)), torch.nn.AvgPool2d((3, 4), (1, 2))): if "quant" in name: atol_rtol = (1, 0) limit = model(inp).numel() convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128) if "nhwc" in name: convert_arg = nhwc(convert_arg) self.check(model, inp, atol_rtol=atol_rtol, limit=limit) self.check( model, inp, convert_args=[convert_arg], atol_rtol=atol_rtol, limit=limit ) def test_adaptive_avg_pool2d(self): for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128): with self.subTest(name): self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp) with self.assertRaisesRegex(Exception, "with output size"): self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp) def test_upsample_nearest2d(self): convert_args = dict(self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128)) for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128): with self.subTest(name): self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp) self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp) self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp) self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp) self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp) self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp) self.check( torch.nn.UpsamplingNearest2d(size=(24, 32)), inp, convert_args=[convert_args[name]] ) self.check( torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp, convert_args=[convert_args[name]] ) def test_linear(self): torch.manual_seed(29) self.check(torch.nn.Linear(16, 32), torch.randn(2, 16)) self.check( torch.nn.Linear(16, 32), torch.randn(2, 16), convert_args=[torch.zeros(0, 16)]) def test_conv2d(self): cases = [ # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name ( 4, 8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"), # noqa: E201,E241 ( 4, 8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"), # noqa: E201,E241 ( 4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"), # noqa: E201,E241 ( 8, 8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"), # noqa: E201,E241 ( 4, 8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"), # noqa: E201,E241 ( 4, 4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"), # noqa: E201,E241 ( 8, 4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"), # noqa: E201,E241 ] for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]: for case in cases: in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name = case with self.subTest("{}-{}".format(kind, name)): inp = torch.randn(input_dim) model = torch.nn.Conv2d(in_ch, out_ch, kernel, stride, padding, groups=groups, bias=bool(bias)) output_size = model(inp).numel() atol_rtol = None limit = None convert_dims = (0, in_ch, 0, 0) convert_arg = torch.zeros(*convert_dims) if "quant" in kind: model = torch.nn.Sequential(model) model.eval() model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack') model = torch.ao.quantization.prepare(model) model(inp) model = torch.ao.quantization.convert(model) inp = qpt(inp, 1.0 / 16, 128) # I've seen numerical differences between QNNPACK and NNAPI, # but never more than 1 quantum, and never more than ~1% of # the output in this test. atol_rtol = (1, 0) limit = output_size * 0.03 convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128) if "nhwc" in kind: inp = nhwc(inp) convert_arg = nhwc(convert_arg) self.check(model, inp, atol_rtol=atol_rtol, limit=limit) self.check( model, inp, convert_args=[convert_arg], atol_rtol=atol_rtol, limit=limit ) def test_conv2d_transpose(self): torch.manual_seed(29) in_ch, out_ch, kernel = (5, 7, (2, 2)) input_dim = (4, 5, 3, 3) convert_dims = input_dim[:2] + (0, 0) for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]: with self.subTest(kind): inp = torch.randn(input_dim) model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel) output_size = model(inp).numel() atol_rtol = (0.0002, 0) limit = None convert_arg = torch.zeros(*convert_dims) if "quant" in kind: model = torch.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel) model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack') inp = qpt(inp, 1.0 / 16, 128) # I've seen numerical differences between QNNPACK and NNAPI, # but never more than 1 quantum, and never more than ~10% of # the output in this test. atol_rtol = (1, 0) limit = output_size * 0.1 convert_arg = qpt(convert_arg, 1.0 / 16, 128) if "nhwc" in kind: inp = nhwc(inp) convert_arg = nhwc(convert_arg) self.check(model, inp, atol_rtol=atol_rtol, limit=limit) self.check( model, inp, convert_args=[convert_arg], atol_rtol=atol_rtol, limit=limit ) def test_qadd(self): func = torch.nn.quantized.QFunctional() func.scale = 0.5 func.zero_point = 120 class AddMod(torch.nn.Module): def forward(self, lhs, rhs): return func.add(lhs, rhs) class AddReluMod(torch.nn.Module): def forward(self, lhs, rhs): return func.add_relu(lhs, rhs) class MulMod(torch.nn.Module): def forward(self, lhs, rhs): return func.mul(lhs, rhs) for (name, mod) in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]: with self.subTest(name): self.check( mod(), [ qpt([1.0, 2.0], 0.25, 128), qpt([3.0, 4.0], 0.25, 128), ]) self.check( mod(), [ qpt([[1.0, 2.0]], 0.25, 128), qpt([[3.0, 4.0]], 0.25, 128), ], convert_args=[ qpt([[1.0, 2.0]], 0.25, 128), qpt(torch.zeros((1, 2)), 0.25, 128), ] ) self.check( mod(), [ qpt([[1.0, 2.0]], 0.25, 128), qpt([[3.0, 4.0]], 0.25, 128), ], convert_args=[ qpt(torch.zeros((1, 2)), 0.25, 128), qpt([[3.0, 4.0]], 0.25, 128), ] ) self.check( mod(), [ qpt([[1.0, 2.0]], 0.25, 128), qpt([[3.0, 4.0]], 0.25, 128), ], convert_args=[ qpt(torch.zeros((1, 2)), 0.25, 128), qpt(torch.zeros((1, 2)), 0.25, 128), ] ) # NOTE: NNAPI qadd supports broadcast, but PT does not. def test_qlinear(self): torch.manual_seed(29) weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8) bias = torch.randn(16) mod = torch.nn.quantized.Linear(32, 16) mod.set_weight_bias(weight, bias) inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8) self.check(mod, inp) def test_seblock_mul(self): class MulModel(torch.nn.Module): def forward(self, lhs, rhs): return lhs * rhs self.check( MulModel(), [ nhwc(torch.randn(2, 3, 4, 4)), torch.randn(1, 3, 1, 1), ]) def test_multi_output(self): class MultiModel(torch.nn.Module): def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]: the_sum = lhs + rhs the_diff = lhs - rhs return the_sum, the_diff self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])]) if __name__ == '__main__': run_tests()