import torch
from ..modules import Module
from . import comm
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast
from torch._utils import _get_device_index

from collections import OrderedDict

if TYPE_CHECKING:
    from torch.jit import ScriptModule
    from torch.jit._state import EnabledProxy


__all__ = ['replicate']


def _is_script_module(module: Module) -> bool:
    """Return True when ``module`` is an instance of ``torch.jit.ScriptModule``."""
    import torch.jit
    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module: Module) -> bool:
    """Return True when ``module`` is an instance of ``torch._C.ScriptMethod``."""
    import torch.jit
    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module() -> "ScriptModule":
    """Construct and return a fresh, empty ``torch.jit.ScriptModule``."""
    import torch.jit
    return torch.jit.ScriptModule()


def _is_jit_enabled() -> "EnabledProxy":
    """Return the ``torch.jit._state._enabled`` proxy (used for its truthiness)."""
    import torch.jit._state
    return torch.jit._state._enabled


# Check if we can safely replicate the module.
# there are two types of module:
# 1. python modules
# 2. ScriptModule
#
# currently a module cannot be replicated properly if the descendants of
# any ScriptModule contains python module (type 1 above)
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
    """Report whether ``module`` (and everything below it) can be replicated.

    Args:
        module: root of the module tree to inspect.
        memo: modules already visited by an enclosing call; a new set is
            created when omitted. The set is updated in place.

    Returns:
        ``True`` if the JIT is disabled, or if every ScriptModule in the tree
        has only ScriptModule descendants; ``False`` otherwise.
    """

    def descendant_modules(root: Module) -> Iterator[Module]:
        # root.modules() yields root itself first; drop it so that only the
        # strict descendants remain.
        walker = root.modules()
        next(walker)
        return walker

    if not _is_jit_enabled():
        return True

    if memo is None:
        memo = set()

    # Remember every module we have looked at.
    memo.add(module)

    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(sub) for sub in descendant_modules(module))

    # Any unreplicatable module makes the whole check fail right away, so a
    # child that was already visited can simply be skipped.
    return all(
        child in memo or _replicatable_module(child, memo)
        for child in module.children()
    )


def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
    """Broadcast ``tensors`` to ``devices`` and group the copies per device.

    Args:
        tensors: tensors to copy.
        devices: target devices.
        detach: when true, delegate to ``comm.broadcast_coalesced``; otherwise
            go through the ``Broadcast`` autograd function.

    Returns:
        The result of ``comm.broadcast_coalesced`` when ``detach`` is true.
        Otherwise the flat output of ``Broadcast.apply`` cut into consecutive
        chunks of ``len(tensors)`` items, or ``[]`` when ``tensors`` is empty.
    """
    from ._functions import Broadcast

    if detach:
        return comm.broadcast_coalesced(tensors, devices)

    # Not detaching: use the autograd function to broadcast.
    group_size = len(tensors)
    if group_size == 0:
        return []

    flat_copies = Broadcast.apply(devices, *tensors)
    return [
        flat_copies[start:start + group_size]
        for start in range(0, len(flat_copies), group_size)
    ]


T = TypeVar("T", bound=Module)


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
    """Build one replica of ``network`` for every entry of ``devices``.

    Args:
        network: module tree to replicate.
        devices: target devices; each entry is normalised with
            ``_get_device_index``.
        detach: forwarded to ``_broadcast_coalesced_reshape`` for parameters
            and for buffers that require grad. Buffers that do not require
            grad are always broadcast with ``detach=True``.

    Returns:
        A list with one replica of ``network`` per device, in ``devices``
        order; an empty list when ``devices`` is empty.

    Raises:
        RuntimeError: if ``_replicatable_module(network)`` is false.
    """
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    if not devices:
        return []

    devices = [_get_device_index(dev, True) for dev in devices]

    # --- parameters -------------------------------------------------------
    params = list(network.parameters())
    param_indices = {param: pos for pos, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    # --- buffers, split by whether their copies must stay in autograd ------
    buffers_rg: List[torch.Tensor] = []
    buffers_not_rg: List[torch.Tensor] = []
    for buf in network.buffers():
        bucket = buffers_rg if (buf.requires_grad and not detach) else buffers_not_rg
        bucket.append(buf)

    buffer_indices_rg = {buf: pos for pos, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: pos for pos, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    # --- shallow replicas of every module, one list per device -------------
    modules = list(network.modules())
    module_indices: Dict[Module, int] = {mod: pos for pos, mod in enumerate(modules)}
    module_copies: List[List[Module]] = [[] for _ in devices]

    for module in modules:
        for per_device in module_copies:
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `model.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()
            per_device.append(replica)

    # --- wire children, parameters and buffers into each replica -----------
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for per_device in module_copies:
                    per_device[i]._modules[key] = None
                continue
            child_pos = module_indices[child]
            for per_device in module_copies:
                setattr(per_device[i], key, per_device[child_pos])

        for key, param in module._parameters.items():
            if param is None:
                for per_device in module_copies:
                    per_device[i]._parameters[key] = None
                continue
            param_pos = param_indices[param]
            for j, per_device in enumerate(module_copies):
                replica = per_device[i]
                param_copy = param_copies[j][param_pos]
                # parameters in replicas are no longer leaves,
                # so setattr them as non-parameter attributes
                setattr(replica, key, param_copy)
                # expose the parameter for DDP
                replica._former_parameters[key] = param_copy

        for key, buf_tensor in module._buffers.items():
            if buf_tensor is None:
                for per_device in module_copies:
                    per_device[i]._buffers[key] = None
                continue
            # Pick the broadcast result matching the bucket the buffer was
            # placed in above.
            if buf_tensor.requires_grad and not detach:
                buffer_copies = buffer_copies_rg
                buffer_pos = buffer_indices_rg[buf_tensor]
            else:
                buffer_copies = buffer_copies_not_rg
                buffer_pos = buffer_indices_not_rg[buf_tensor]
            for j, per_device in enumerate(module_copies):
                setattr(per_device[i], key, buffer_copies[j][buffer_pos])

    # The first module of each per-device list is the replica of `network`.
    return [cast(T, per_device[0]) for per_device in module_copies]