# mypy: allow-untyped-defs
from torch import nn


class QuantStub(nn.Module):
    r"""Quantize stub module.

    Before ``convert`` this module's ``forward`` returns its input unchanged;
    it serves as the insertion point where calibration observes the input
    tensor. ``convert`` swaps it for ``nnq.Quantize``.

    Args:
        qconfig: quantization configuration for the tensor. If it is not
            provided (``None``), no ``qconfig`` attribute is set here and the
            qconfig is expected to come from parent modules.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Only record an explicit config. An absent attribute lets the
        # qconfig be inherited from parent modules. Compare against None
        # rather than truthiness so falsy-but-valid config objects are kept.
        if qconfig is not None:
            self.qconfig = qconfig

    def forward(self, x):
        # Identity until `convert` replaces this module.
        return x


class DeQuantStub(nn.Module):
    r"""Dequantize stub module.

    Before ``convert`` this module behaves as the identity: ``forward``
    returns its input unchanged. ``convert`` swaps it for ``nnq.DeQuantize``.

    Args:
        qconfig: quantization configuration for the tensor. If it is not
            provided (``None``), no ``qconfig`` attribute is set here and the
            qconfig is expected to come from parent modules.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Same contract as QuantStub: set the attribute only for an explicit
        # (non-None) config so that parent-module inheritance still applies.
        if qconfig is not None:
            self.qconfig = qconfig

    def forward(self, x):
        # Identity until `convert` replaces this module.
        return x


class QuantWrapper(nn.Module):
    r"""Wrap a module between a ``QuantStub`` and a ``DeQuantStub``.

    The call to the wrapped module is surrounded by calls to the quant and
    dequant stubs. The ``quantization`` utility functions use this to add the
    quant and dequant modules.

    Before ``convert``, ``QuantStub`` only observes the input tensor. After
    ``convert``, it is swapped for ``nnq.Quantize``, which performs the actual
    quantization. ``DeQuantStub`` is handled the same way (swapped for
    ``nnq.DeQuantize``).

    The wrapped module's ``qconfig`` attribute (if any) is passed to both
    stubs, and the wrapper starts in the same train/eval mode as the wrapped
    module.
    """

    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        # Propagate the wrapped module's qconfig (None if it has none) so
        # both stubs are configured consistently with it.
        qconfig = getattr(module, "qconfig", None)
        self.add_module('quant', QuantStub(qconfig))
        self.add_module('dequant', DeQuantStub(qconfig))
        self.add_module('module', module)
        # Match the wrapped module's train/eval mode. `train` applies the
        # flag recursively, so the freshly created stubs match it too.
        self.train(module.training)

    def forward(self, X):
        # quant -> wrapped module -> dequant
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)