# Owner(s): ["oncall: mobile"] import torch from torch.nn import functional as F from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing import FileCheck import io class TestMetalRewritePass(TestCase): @staticmethod def validate_transformed_module( # To please flake self, pattern_count_map, data_shape, prepack_removal=False, fuse_clamping_ops=False): module_instance = self scripted_model = torch.jit.script(module_instance) scripted_model.eval() input_data = torch.normal(1, 20, size=data_shape) ref_result = scripted_model(input_data) torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c) if fuse_clamping_ops or prepack_removal: scripted_model._c = torch._C._freeze_module(scripted_model._c) if fuse_clamping_ops: torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c) if prepack_removal: torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c) buffer = io.BytesIO() torch.jit.save(scripted_model, buffer) buffer.seek(0) deserialized_scripted_model = torch.jit.load(buffer) for pattern, v in pattern_count_map.items(): if (v == 0): FileCheck().check(pattern).run(deserialized_scripted_model.graph) elif (v == -1): FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) else: FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) def test_conv(self): # Conv params batch_size = 2 input_channels_per_group = 6 height = 16 width = 16 output_channels_per_group = 6 groups = 4 kernel_h = kernel_w = 3 stride_h = stride_w = 1 pad_h = pad_w = 1 dilation = 1 input_channels = input_channels_per_group * groups output_channels = output_channels_per_group * groups kernels = (kernel_h, kernel_w) strides = (stride_h, stride_w) paddings = (pad_h, pad_w) dilations = (dilation, dilation) conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) conv_bias_shape = (output_channels) class Conv2D(torch.nn.Module): def __init__(self): super(Conv2D, self).__init__() self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) self.strides = strides self.paddings = paddings self.dilations = dilations self.groups = groups def forward(self, x): return F.conv2d(x, self.weight, self.bias, self.strides, self.paddings, self.dilations, self.groups) data_shape = (batch_size, input_channels, height, width) pattern_count_map = {"Tensor = aten::conv2d": -1, "metal_prepack::conv2d_prepack": 1, "metal_prepack::conv2d_run": 1} TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) class Conv2DRelu(torch.nn.Module): def __init__(self): super(Conv2DRelu, self).__init__() self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) self.strides = strides self.paddings = paddings self.dilations = dilations self.groups = groups def forward(self, x): o = F.conv2d(x, self.weight, self.bias, self.strides, self.paddings, self.dilations, self.groups) o = F.relu(o) return o data_shape = (batch_size, input_channels, height, width) pattern_count_map = {"Tensor = aten::conv2d": -1, "metal_prepack::conv2d_prepack": 1, "metal_prepack::conv2d_run": 1} TestMetalRewritePass.validate_transformed_module( Conv2DRelu(), pattern_count_map, data_shape) pattern_count_map["aten::relu"] = 1 pattern_count_map["metal_prepack::conv2d_prepack"] = -1 TestMetalRewritePass.validate_transformed_module( Conv2DRelu(), pattern_count_map, data_shape, prepack_removal=True) pattern_count_map["aten::relu"] = -1 TestMetalRewritePass.validate_transformed_module( Conv2DRelu(), pattern_count_map, data_shape, prepack_removal=True, fuse_clamping_ops=True) class Conv2DHardtanh(torch.nn.Module): def __init__(self): super(Conv2DHardtanh, self).__init__() self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) self.strides = strides self.paddings = paddings self.dilations = dilations self.groups = groups def forward(self, x): o = F.conv2d(x, self.weight, self.bias, self.strides, self.paddings, self.dilations, self.groups) o = F.hardtanh(o) return o data_shape = (batch_size, input_channels, height, width) pattern_count_map = {"Tensor = aten::conv2d": -1, "metal_prepack::conv2d_prepack": 1, "metal_prepack::conv2d_run": 1} TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) pattern_count_map["aten::hardtanh"] = 1 pattern_count_map["metal_prepack::conv2d_prepack"] = -1 TestMetalRewritePass.validate_transformed_module( Conv2DHardtanh(), pattern_count_map, data_shape, prepack_removal=True) pattern_count_map["aten::hardtanh"] = -1 TestMetalRewritePass.validate_transformed_module( Conv2DRelu(), pattern_count_map, data_shape, prepack_removal=True, fuse_clamping_ops=True) if __name__ == "__main__": run_tests()