# mypy: allow-untyped-defs
import copy
import itertools
import warnings
import inspect
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.nn.intrinsic import _FusedModule

from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    no_observer_set,
    _has_special_act_post_process,
    _get_special_act_post_process,
)
# Absolute import, consistent with every other import in this module
# (equivalent to the former ``from .utils import ...``).
from torch.ao.quantization.utils import get_qparam_dict, has_no_children_ignoring_parametrizations
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.qconfig import (
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    _activation_is_memoryless,
)
from torch.nn.utils.parametrize import type_before_parametrizations
from torch.ao.quantization.observer import _is_activation_post_process

# TODO remove this once BC is no longer required to avoid a SEV
from torch.ao.quantization.observer import (  # noqa: F401
    _is_activation_post_process as is_activation_post_process
)

__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]

# Default float -> observed and observed -> quantized custom-module class
# mappings used by `prepare` and `_convert` when the caller supplies none.
_DEFAULT_CUSTOM_CONFIG_DICT = {
    'float_to_observed_custom_module_class': {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    'observed_to_quantized_custom_module_class': {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    }
}


def get_default_custom_config_dict():
    r"""Defines the default custom config dict.

    Note: this returns the module-level dictionary object itself (not a
    copy), so callers should treat it as read-only.
    """
    return _DEFAULT_CUSTOM_CONFIG_DICT


def _propagate_qconfig_helper(module, qconfig_dict, qconfig_parent=None,
                              prefix='', prepare_custom_config_dict=None):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                     configuration
        qconfig_parent: quantization config of parent module, we will fallback to
                       this config when there is no specified config for current
                       module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules
                                    see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """
    # Precedence (lowest -> highest): parent qconfig, qconfig keyed by type,
    # qconfig keyed by fully qualified name, qconfig already set on the module.
    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        # do not propagate qconfig to child if child is non traceable
        # NOTE(review): `name` here is the child's local name (not the fully
        # qualified `module_prefix`), and `prepare_custom_config_dict` is not
        # forwarded in the recursive call below, so this filter only applies to
        # direct children of the top-level module.
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )


def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, qconfig applies to all submodules of a
            given module unless qconfig for the submodules are specified (when
            the submodule already has qconfig attribute)
        prepare_custom_config_dict: dictionary for custom handling of modules
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    _propagate_qconfig_helper(module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict)


def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output
    """
    return self.activation_post_process(output)


def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the input (the first positional
    argument of the module's forward)
    """
    return self.activation_post_process(input[0])


def _register_activation_post_process_hook(module, pre_hook=False):
    """Register the observer hook on `module`, ahead of any existing hooks.

    A forward pre-hook (observing the input) is used when `pre_hook` is True,
    otherwise a forward hook (observing the output). The module must already
    have an `activation_post_process` attribute.
    """
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    if pre_hook:
        module.register_forward_pre_hook(
            _observer_forward_pre_hook, prepend=True
        )
    else:
        module.register_forward_hook(
            _observer_forward_hook, prepend=True
        )


def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None,
                   device=None, custom_module_class_mapping=None):
    r"""Add observer for the leaf child of the module.

    This function insert observer module to all leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have observers added to them
            if they are leaf nodes
        non_leaf_module_list: list of non-leaf modules we want to add observer
        device: parent device, if any
        custom_module_class_mapping: mapping from float module type to observed
            custom module class (which must provide `from_float`); matching
            children with a qconfig are replaced by their observed version

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = _get_unique_devices_(module)
        assert len(devices) <= 1, (
            f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, 'qconfig') and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """ Adds an activation post process module and register
        a pre or post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module('activation_post_process', get_activation_post_process(
                m.qconfig, device, special_act_post_process))
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            _register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))

    for name, child in module.named_children():
        # TODO remove Dropout special after codebase stable
        if type_before_parametrizations(child) in [nn.Dropout]:
            continue
        elif issubclass(type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)):
            if needs_observation(child):
                assert hasattr(child, "activation_post_process"), (
                    f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
                )
                child.activation_post_process = get_activation_post_process(child.qconfig, device)
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping:
            observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
                insert_activation_post_process(observed_child)
        else:
            _add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if (
        has_no_children_ignoring_parametrizations(module)
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)
    # This is a special case for AdaRound eager mode
    # AdaRound contains weight_fake_quant to be propagated from API to convert
    # leaf node check with a number of children looks naive assumption that blocks
    # Adding an exception case for AdaRound
    if (
        hasattr(module, "weight_fake_quant")
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)


def _get_unique_devices_(module):
    """Return the set of devices used by `module`'s parameters and buffers."""
    return {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}


def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
    Note that this function will modify the children of module inplace and it
    can return a new module which wraps the input module as well.

    Args:
        module: input module with qconfig attributes for all the leaf modules
        that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig:
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module


def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model; it is mutated only when `inplace=True`,
            otherwise a deep copy is prepared and returned
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function

    .. code-block:: python

        # Example of prepare_custom_config_dict:
        prepare_custom_config_dict = {
            # user will manually define the corresponding observed
            # module class which has a from_float class method that converts
            # float custom module to observed custom module
            "float_to_observed_custom_module_class": {
                CustomModule: ObservedCustomModule
            }
        }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
        warnings.warn("None of the submodule got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    _add_observer_(
        model, qconfig_propagation_list, observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping)
    return model


def _remove_activation_post_process(module):
    """Remove the observer submodule and the observer hooks from `module`."""
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, 'activation_post_process') and \
       _is_activation_post_process(module.activation_post_process):
        delattr(module, 'activation_post_process')

    # remove activation_post_process pre and post hooks
    def remove_hooks(pre_hook=False):
        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
        observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook
        # Collect ids first: popping while iterating the dict is not allowed.
        handle_ids_to_remove = set()
        for handle_id, hook_fn in hook_map.items():
            if hook_fn is observer_hook:
                handle_ids_to_remove.add(handle_id)
        for handle_id in handle_ids_to_remove:
            hook_map.pop(handle_id)

    remove_hooks(pre_hook=True)
    remove_hooks(pre_hook=False)


# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)


def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    First it will prepare the model for calibration, then it calls
    `run_fn` which will run the calibration step, after that we will
    convert the model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: correspondence between original module types and quantized counterparts

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model


def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
                     mapping=None, inplace=False):
    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.

    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
    by default is performed for layers with large weights size - i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
    If `qconfig` is provided, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to quantization
              configuration, qconfig applies to all submodules of a given
              module unless qconfig for the submodules are specified (when the
              submodule already has qconfig attribute). Entries in the dictionary
              need to be QConfig instances.

            - A set of types and/or submodule names to apply dynamic quantization to,
              in which case the `dtype` argument is used to specify the bit-width

        dtype: one of ``torch.qint8``, ``torch.float16``, ``torch.quint8`` or
            ``torch.quint4x2``; selects the default qconfig when `qconfig_spec`
            is None or a set
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
            with which the submodule needs to be replaced

    Raises:
        ValueError: `qconfig_spec` is None and `dtype` has no default qconfig.
        RuntimeError: `qconfig_spec` is a set and `dtype` is not supported.

    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear : default_dynamic_qconfig,
                nn.LSTM : default_dynamic_qconfig,
                nn.GRU : default_dynamic_qconfig,
                nn.LSTMCell : default_dynamic_qconfig,
                nn.RNNCell : default_dynamic_qconfig,
                nn.GRUCell : default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear : float16_dynamic_qconfig,
                nn.LSTM : float16_dynamic_qconfig,
                nn.GRU : float16_dynamic_qconfig,
                nn.LSTMCell : float16_dynamic_qconfig,
                nn.RNNCell : float16_dynamic_qconfig,
                nn.GRUCell : float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag : float_qparams_weight_only_qconfig,
                nn.Embedding : float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag : float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please")
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            # A single formatted message; passing two positional args here
            # would make str(exc) render as a tuple.
            raise RuntimeError(f"Unknown dtype specified for quantize_dynamic: {dtype}")
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model


def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization-aware training.

    Submodules whose types appear in `mapping` are swapped for their QAT
    counterparts, then activation post-process modules built from each
    submodule's `qconfig` are attached. The model is not converted to a
    quantized model; call :func:`convert` after training for that.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    Args:
        model: input model, must be in training mode
        mapping: dictionary that maps float modules to QAT modules to be
                 replaced (defaults to the default QAT module mappings).
        inplace: carry out model transformations in-place, the original module
                 is mutated
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model


def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
                function that simply runs the prepared model or a training
                loop
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module
                 is mutated

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, inplace=True)
    return model


def convert(
        module, mapping=None, inplace=False, remove_qconfig=True,
        is_reference=False, convert_custom_config_dict=None,
        use_precomputed_fake_quant=False):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class. And remove qconfig at the
    end if remove_qconfig is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
                   module type, can be overwritten to allow swapping user defined
                   Modules
        `inplace`: carry out model transformations in-place, the original module
                   is mutated
        `remove_qconfig`: if True, remove `qconfig` attributes and observers
                   (submodules and hooks) after conversion
        `is_reference`: when `mapping` is None, use the reference quantized
                   module mappings instead of the default static ones
        `convert_custom_config_dict`: custom configuration dictionary for convert function
        `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant

    .. code-block:: python

        # Example of convert_custom_config_dict:
        convert_custom_config_dict = {
            # user will manually define the corresponding quantized
            # module class which has a from_observed class method that converts
            # observed custom module to quantized custom module
            "observed_to_quantized_custom_module_class": {
                ObservedCustomModule: QuantizedCustomModule
            }
        }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    if not inplace:
        module = copy.deepcopy(module)
    _convert(
        module, mapping, inplace=True, is_reference=is_reference,
        convert_custom_config_dict=convert_custom_config_dict,
        use_precomputed_fake_quant=use_precomputed_fake_quant)
    if remove_qconfig:
        _remove_qconfig(module)
    return module


def _convert(
        module, mapping=None, inplace=False,
        is_reference=False, convert_custom_config_dict=None,
        use_precomputed_fake_quant=False):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated
        is_reference: a flag to enable quantized reference module
        use_precomputed_fake_quant: a flag to enable use of precomputed fake quant

    """
    if mapping is None:
        mapping = get_default_static_quant_reference_module_mappings() if is_reference \
            else get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if not isinstance(mod, _FusedModule) and \
           type_before_parametrizations(mod) not in custom_module_class_mapping:
            _convert(mod, mapping, True,  # inplace
                     is_reference, convert_custom_config_dict,
                     use_precomputed_fake_quant=use_precomputed_fake_quant)
        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant)

    # Reassign after the loop so named_children() is not mutated mid-iteration.
    for key, value in reassign.items():
        module._modules[key] = value

    return module


def swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module
        custom_module_class_mapping: mapping from observed custom module type to
            a class providing `from_observed`; checked before `mapping`
        use_precomputed_fake_quant: forwarded to `from_float` when the target
            class's `from_float` accepts that keyword

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    if hasattr(mod, 'qconfig') and mod.qconfig is not None:
        swapped = False
        if type_before_parametrizations(mod) in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[type_before_parametrizations(mod)].from_observed(mod)
            swapped = True
        elif type_before_parametrizations(mod) in mapping:
            qmod = mapping[type_before_parametrizations(mod)]
            if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE:
                assert mod.qconfig is not None
                weight_post_process = mod.qconfig.weight()
                weight_post_process(mod.weight)
                weight_qparams = get_qparam_dict(weight_post_process)
                new_mod = qmod.from_float(mod, weight_qparams)
            else:
                # Only pass use_precomputed_fake_quant to from_float
                # implementations that declare it.
                sig = inspect.signature(qmod.from_float)
                if 'use_precomputed_fake_quant' in sig.parameters:
                    new_mod = qmod.from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
                else:
                    new_mod = qmod.from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input
            for pre_hook_fn in mod._forward_pre_hooks.values():
                new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = _get_unique_devices_(mod)
            assert len(devices) <= 1, (
                f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
            )
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod


def _get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.
    This is mainly used for quantization accuracy debug
    Args:
        mod: the top module we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
    """
    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + '.'

    if hasattr(mod, 'activation_post_process'):
        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_observer_dict(child, target_dict, module_prefix)