# mypy: allow-untyped-defs import torch from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten from .module_tracker import ModuleTracker from typing import List, Any, Dict, Optional, Union, Tuple, Iterator from collections import defaultdict from torch.utils._python_dispatch import TorchDispatchMode from torch._decomp import register_decomposition from math import prod from functools import wraps import warnings __all__ = ["FlopCounterMode", "register_flop_formula"] aten = torch.ops.aten def get_shape(i): if isinstance(i, torch.Tensor): return i.shape return i flop_registry: Dict[Any, Any] = {} def shape_wrapper(f): @wraps(f) def nf(*args, out_val=None, **kwargs): args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val)) return f(*args, out_shape=out_shape, **kwargs) return nf def register_flop_formula(targets, get_raw=False): def register_fun(flop_formula): if not get_raw: flop_formula = shape_wrapper(flop_formula) register_decomposition(targets, registry=flop_registry, unsafe=True)(flop_formula) return flop_formula return register_fun @register_flop_formula(aten.mm) def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int: """Count flops for matmul.""" # Inputs should be a list of length 2. # Inputs contains the shapes of two matrices. m, k = a_shape k2, n = b_shape assert k == k2 # NB(chilli): Should be 2 * k - 1 technically for FLOPs. return m * n * 2 * k @register_flop_formula(aten.addmm) def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int: """Count flops for addmm.""" return mm_flop(a_shape, b_shape) @register_flop_formula(aten.bmm) def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int: """Count flops for the bmm operation.""" # Inputs should be a list of length 2. # Inputs contains the shapes of two tensor. b, m, k = a_shape b2, k2, n = b_shape assert b == b2 assert k == k2 # NB(chilli): Should be 2 * k - 1 technically for FLOPs. flop = b * m * n * 2 * k return flop @register_flop_formula(aten.baddbmm) def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int: """Count flops for the baddbmm operation.""" # Inputs should be a list of length 3. # Inputs contains the shapes of three tensors. return bmm_flop(a_shape, b_shape) def conv_flop_count( x_shape: List[int], w_shape: List[int], out_shape: List[int], transposed: bool = False, ) -> int: """Count flops for convolution. Note only multiplication is counted. Computation for bias are ignored. Flops for a transposed convolution are calculated as flops = (x_shape[2:] * prod(w_shape) * batch_size). Args: x_shape (list(int)): The input shape before convolution. w_shape (list(int)): The filter shape. out_shape (list(int)): The output shape after convolution. transposed (bool): is the convolution transposed Returns: int: the number of flops """ batch_size = x_shape[0] conv_shape = (x_shape if transposed else out_shape)[2:] c_out, c_in, *filter_size = w_shape """ General idea here is that for a regular conv, for each point in the output spatial dimension we convolve the filter with something (hence `prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by 1. batch_size, 2. the cross product of input and weight channels. For the transpose, it's not each point in the *output* spatial dimension but each point in the *input* spatial dimension. """ # NB(chilli): I don't think this properly accounts for padding :think: # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs. flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2 return flop @register_flop_formula([aten.convolution, aten._convolution]) def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int: """Count flops for convolution.""" return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) @register_flop_formula(aten.convolution_backward) def conv_backward_flop( grad_out_shape, x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, _output_padding, _groups, output_mask, out_shape) -> int: def t(shape): return [shape[1], shape[0]] + list(shape[2:]) flop_count = 0 """ Let's say we have a regular 1D conv {A, B, C} [inp] {i, j} [weight] => (conv) {Ai + Bj, Bi + Cj} [out] And as a reminder, the transposed conv of the above is => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out] For the backwards of conv, we now have {D, E} [grad_out] {A, B, C} [inp] {i, j} [weight] # grad_inp as conv_transpose(grad_out, weight) Let's first compute grad_inp. To do so, we can simply look at all the multiplications that each element of inp is involved in. For example, A is only involved in the first element of the output (and thus only depends upon D in grad_out), and C is only involved in the last element of the output (and thus only depends upon E in grad_out) {Di, Dj + Ei, Ej} [grad_inp] Note that this corresponds to the below conv_transpose. This gives us the output_mask[0] branch, which is grad_inp. {D, E} [inp (grad_out)] {i, j} [weight] => (conv_transpose) {Di, Dj + Ei, Ej} [out (grad_inp)] I leave the fact that grad_inp for a transposed conv is just conv(grad_out, weight) as an exercise for the reader. # grad_weight as conv(inp, grad_out) To compute grad_weight, we again look at the terms in the output, which as a reminder is: => {Ai + Bj, Bi + Cj} [out] => {D, E} [grad_out] If we manually compute the gradient for the weights, we see it's {AD + BE, BD + CE} [grad_weight] This corresponds to the below conv {A, B, C} [inp] {D, E} [weight (grad_out)] => (conv) {AD + BE, BD + CE} [out (grad_weight)] # grad_weight of transposed conv as conv(grad_out, inp) As a reminder, the terms of the output of a transposed conv are: => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out] => {D, E, F, G} [grad_out] Manually computing the gradient for the weights, we see it's {AD + BE + CF, AE + BF + CG} [grad_weight] This corresponds to the below conv {D, E, F, G} [inp (grad_out)] {A, B, C} [weight (inp)] => (conv) {AD + BE + CF, AE + BF + CG} [out (grad_weight)] For the full backwards formula, there are also some details involving transpose of the batch/channel dimensions and groups, but I skip those for the sake of brevity (and they're pretty similar to matmul backwards) Check [conv backwards decomposition as conv forwards] """ # grad_inp as conv_transpose(grad_out, weight) if output_mask[0]: grad_input_shape = get_shape(out_shape[0]) flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed) if output_mask[1]: grad_weight_shape = get_shape(out_shape[1]) if transposed: # grad_weight of transposed conv as conv(grad_out, inp) flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False) else: # grad_weight as conv(inp, grad_out) flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False) return flop_count def sdpa_flop_count(query_shape, key_shape, value_shape): """ Count flops for self-attention. NB: We can assume that value_shape == key_shape """ b, h, s_q, d_q = query_shape _b2, _h2, s_k, _d2 = key_shape _b3, _h3, _s3, d_v = value_shape assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2 total_flops = 0 # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k] total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k)) # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v] total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v)) return total_flops @register_flop_formula([aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention]) def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: """Count flops for self-attention.""" # NB: We aren't accounting for causal attention here return sdpa_flop_count(query_shape, key_shape, value_shape) def _unpack_flash_attention_nested_shapes( *, query, key, value, grad_out=None, cum_seq_q, cum_seq_k, max_q, max_k, ) -> Iterator[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Optional[Tuple[int, ...]]]]: """ Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for each batch element. In the case that this isn't a NestedTensor kernel, then it just yields the original shapes. """ if cum_seq_q is not None: # This means we should be dealing with a Nested Jagged Tensor query. # The inputs will have shape (sum(sequence len), heads, dimension) # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension) # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension) # So the flops calculation in this case is an overestimate of the actual flops. assert len(key.shape) == 3 assert len(value.shape) == 3 assert grad_out is None or grad_out.shape == query.shape _, h_q, d_q = query.shape _, h_k, d_k = key.shape _, h_v, d_v = value.shape assert cum_seq_q is not None assert cum_seq_k is not None assert cum_seq_q.shape == cum_seq_k.shape seq_q_lengths = (cum_seq_q[1:] - cum_seq_q[:-1]).tolist() seq_k_lengths = (cum_seq_k[1:] - cum_seq_k[:-1]).tolist() for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths): new_query_shape = (1, h_q, seq_q_len, d_q) new_key_shape = (1, h_k, seq_k_len, d_k) new_value_shape = (1, h_v, seq_k_len, d_v) new_grad_out_shape = new_query_shape if grad_out is not None else None yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape return yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None def _unpack_efficient_attention_nested_shapes( *, query, key, value, grad_out=None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ) -> Iterator[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Optional[Tuple[int, ...]]]]: """ Given inputs to a efficient_attention_(forward|backward) kernel, this will handle behavior for NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for each batch element. In the case that this isn't a NestedTensor kernel, then it just yields the original shapes. """ if cu_seqlens_q is not None: # Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention. # # This means we should be dealing with a Nested Jagged Tensor query. # The inputs will have shape (sum(sequence len), heads, dimension) # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension) # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension) # So the flops calculation in this case is an overestimate of the actual flops. assert len(key.shape) == 4 assert len(value.shape) == 4 assert grad_out is None or grad_out.shape == query.shape _, _, h_q, d_q = query.shape _, _, h_k, d_k = key.shape _, _, h_v, d_v = value.shape assert cu_seqlens_q is not None assert cu_seqlens_k is not None assert cu_seqlens_q.shape == cu_seqlens_k.shape seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist() seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).tolist() for len_q, len_k in zip(seqlens_q, seqlens_k): new_query_shape = (1, h_q, len_q, d_q) new_key_shape = (1, h_k, len_k, d_k) new_value_shape = (1, h_v, len_k, d_v) new_grad_out_shape = new_query_shape if grad_out is not None else None yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape return yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None @register_flop_formula(aten._flash_attention_forward, get_raw=True) def _flash_attention_forward_flop( query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, *args, out_shape=None, **kwargs ) -> int: """Count flops for self-attention.""" # NB: We aren't accounting for causal attention here # in case this is a nested tensor, we unpack the individual batch elements # and then sum the flops per batch element sizes = _unpack_flash_attention_nested_shapes( query=query, key=key, value=value, cum_seq_q=cum_seq_q, cum_seq_k=cum_seq_k, max_q=max_q, max_k=max_k, ) return sum( sdpa_flop_count(query_shape, key_shape, value_shape) for query_shape, key_shape, value_shape, _ in sizes ) @register_flop_formula(aten._efficient_attention_forward, get_raw=True) def _efficient_attention_forward_flop( query, key, value, bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, *args, **kwargs ) -> int: """Count flops for self-attention.""" # NB: We aren't accounting for causal attention here # in case this is a nested tensor, we unpack the individual batch elements # and then sum the flops per batch element sizes = _unpack_efficient_attention_nested_shapes( query=query, key=key, value=value, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, ) return sum( sdpa_flop_count(query_shape, key_shape, value_shape) for query_shape, key_shape, value_shape, _ in sizes ) def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape): total_flops = 0 b, h, s_q, d_q = query_shape _b2, _h2, s_k, _d2 = key_shape _b3, _h3, _s3, d_v = value_shape _b4, _h4, _s4, _d4 = grad_out_shape assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2 assert d_v == _d4 and s_k == _s3 and s_q == _s4 total_flops = 0 # Step 1: We recompute the scores matrix. # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k] total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k)) # Step 2: We propagate the gradients through the score @ v operation. # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k] total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k)) # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v] total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v)) # Step 3: We propagate th gradients through the k @ v operation # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q] total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q)) # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k] total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k)) return total_flops @register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, aten._scaled_dot_product_flash_attention_backward]) def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: """Count flops for self-attention backward.""" return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) @register_flop_formula(aten._flash_attention_backward, get_raw=True) def _flash_attention_backward_flop( grad_out, query, key, value, out, # named _out_shape to avoid kwarg collision with out_shape created in wrapper logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, *args, **kwargs, ) -> int: # in case this is a nested tensor, we unpack the individual batch elements # and then sum the flops per batch element shapes = _unpack_flash_attention_nested_shapes( query=query, key=key, value=value, grad_out=grad_out, cum_seq_q=cum_seq_q, cum_seq_k=cum_seq_k, max_q=max_q, max_k=max_k, ) return sum( sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) for query_shape, key_shape, value_shape, grad_out_shape in shapes ) @register_flop_formula(aten._efficient_attention_backward, get_raw=True) def _efficient_attention_backward_flop( grad_out, query, key, value, bias, out, # named _out to avoid kwarg collision with out created in wrapper cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, *args, **kwargs, ) -> int: # in case this is a nested tensor, we unpack the individual batch elements # and then sum the flops per batch element shapes = _unpack_efficient_attention_nested_shapes( query=query, key=key, value=value, grad_out=grad_out, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, ) return sum( sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) for query_shape, key_shape, value_shape, grad_out_shape in shapes ) flop_registry = { aten.mm: mm_flop, aten.addmm: addmm_flop, aten.bmm: bmm_flop, aten.baddbmm: baddbmm_flop, aten.convolution: conv_flop, aten._convolution: conv_flop, aten.convolution_backward: conv_backward_flop, aten._scaled_dot_product_efficient_attention: sdpa_flop, aten._scaled_dot_product_flash_attention: sdpa_flop, aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop, aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop, aten._flash_attention_forward: _flash_attention_forward_flop, aten._efficient_attention_forward: _efficient_attention_forward_flop, aten._flash_attention_backward: _flash_attention_backward_flop, aten._efficient_attention_backward: _efficient_attention_backward_flop, } def normalize_tuple(x): if not isinstance(x, tuple): return (x,) return x # Define the suffixes for different orders of magnitude suffixes = ["", "K", "M", "B", "T"] # Thanks BingChat! def get_suffix_str(number): # Find the index of the appropriate suffix based on the number of digits # with some additional overflow. # i.e. 1.01B should be displayed as 1001M, not 1.001B index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3)) return suffixes[index] def convert_num_with_suffix(number, suffix): index = suffixes.index(suffix) # Divide the number by 1000^index and format it to two decimal places value = f"{number / 1000 ** index:.3f}" # Return the value and the suffix as a string return value + suffixes[index] def convert_to_percent_str(num, denom): if denom == 0: return "0%" return f"{num / denom:.2%}" def _pytreeify_preserve_structure(f): @wraps(f) def nf(args): flat_args, spec = tree_flatten(args) out = f(*flat_args) return tree_unflatten(out, spec) return nf class FlopCounterMode(TorchDispatchMode): """ ``FlopCounterMode`` is a context manager that counts the number of flops within its context. It does this using a ``TorchDispatchMode``. It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction. If you do not need hierarchical output, you do not need to use it with a module. Example usage .. code-block:: python mod = ... with FlopCounterMode(mod) as flop_counter: mod.sum().backward() """ def __init__( self, mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None, depth: int = 2, display: bool = True, custom_mapping: Optional[Dict[Any, Any]] = None): self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int)) self.depth = depth self.display = display if custom_mapping is None: custom_mapping = {} if mods is not None: warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2) self.flop_registry = { **flop_registry, **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()} } self.mod_tracker = ModuleTracker() def get_total_flops(self) -> int: return sum(self.flop_counts['Global'].values()) def get_flop_counts(self) -> Dict[str, Dict[Any, int]]: """Return the flop counts as a dictionary of dictionaries. The outer dictionary is keyed by module name, and the inner dictionary is keyed by operation name. Returns: Dict[str, Dict[Any, int]]: The flop counts as a dictionary. """ return {k: dict(v) for k, v in self.flop_counts.items()} def get_table(self, depth=None): if depth is None: depth = self.depth if depth is None: depth = 999999 import tabulate tabulate.PRESERVE_WHITESPACE = True header = ["Module", "FLOP", "% Total"] values = [] global_flops = self.get_total_flops() global_suffix = get_suffix_str(global_flops) is_global_subsumed = False def process_mod(mod_name, depth): nonlocal is_global_subsumed total_flops = sum(self.flop_counts[mod_name].values()) is_global_subsumed |= total_flops >= global_flops padding = " " * depth values = [] values.append([ padding + mod_name, convert_num_with_suffix(total_flops, global_suffix), convert_to_percent_str(total_flops, global_flops) ]) for k, v in self.flop_counts[mod_name].items(): values.append([ padding + " - " + str(k), convert_num_with_suffix(v, global_suffix), convert_to_percent_str(v, global_flops) ]) return values for mod in sorted(self.flop_counts.keys()): if mod == 'Global': continue mod_depth = mod.count(".") + 1 if mod_depth > depth: continue cur_values = process_mod(mod, mod_depth - 1) values.extend(cur_values) # We do a bit of messing around here to only output the "Global" value # if there are any FLOPs in there that aren't already fully contained by # a module. if 'Global' in self.flop_counts and not is_global_subsumed: for idx, value in enumerate(values): values[idx][0] = " " + values[idx][0] values = process_mod('Global', 0) + values if len(values) == 0: values = [["Global", "0", "0%"]] return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right")) def __enter__(self): self.flop_counts.clear() self.mod_tracker.__enter__() super().__enter__() return self def __exit__(self, *args): super().__exit__(*args) self.mod_tracker.__exit__() if self.display: print(self.get_table(self.depth)) def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} out = func(*args, **kwargs) return self._count_flops(func._overloadpacket, out, args, kwargs) def _count_flops(self, func_packet, out, args, kwargs): if func_packet in self.flop_registry: flop_count_func = self.flop_registry[func_packet] flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator] for par in set(self.mod_tracker.parents): self.flop_counts[par][func_packet] += flop_count return out