import torch

from ... import jit
from ... import language as tl
from ... import next_power_of_2


def num_warps(n):
    if n <= 128:
        return 1
    if n <= 256:
        return 2
    if n <= 512:
        return 4
    if n <= 4096:
        return 8
    return 16


@jit
def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT,  #
                             R, extent, stride_zr, stride_hr,  # relative attention
                             scale, is_causal,  #
                             ROW_SIZE: tl.constexpr,  #
                             BLOCK_SIZE: tl.constexpr,  #
                             IS_DENSE: tl.constexpr  #
                             ):
    h = tl.program_id(0)
    m = tl.program_id(1)
    z = tl.program_id(2)
    # create index ranges
    hm = h * tl.num_programs(1) + m
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
    # extract information from LUT
    header = LUT + (hm // BLOCK_SIZE) * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # pointer offset
    off_a = z * stride_xz
    off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE  # block index
    off_a += (m % BLOCK_SIZE) * BLOCK_SIZE  # row index
    # do not need to read column indices in the dense case
    if IS_DENSE:
        ns = tl.arange(0, ROW_SIZE)
    else:
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
        start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
        ns = start_n * BLOCK_SIZE + lane_n
    # load A
    mask = block_n < size
    a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
    a = a.to(tl.float32)
    # compute scaled logits
    out = a
    out *= scale
    # apply relative attention
    if R is not None:
        R += z * stride_zr
        R += h * stride_hr
        off_lo = (extent - m - 1) + ns
        mask_lo = (off_lo >= 0) & (off_lo < extent)
        rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
        out += rel_logits
    out = out.to(tl.float32)
    # apply causal mask
    out = tl.where((ns > m) & is_causal, -float("inf"), out)
    # compute softmax
    out = tl.softmax(out)
    # write-back
    tl.store(Out + off_a + lane_n, out, mask=mask)


@jit
def _blocksparse_softmax_bwd(DA, stride_zdx,  #
                             DOut, stride_zdout,  #
                             Out, stride_zout,  #
                             scale,  #
                             LUT,  #
                             DR, extent, stride_zr, stride_hr, stride_er,  #
                             is_causal,  #
                             ROW_SIZE: tl.constexpr,  #
                             BLOCK_SIZE: tl.constexpr,  #
                             IS_DENSE: tl.constexpr):
    h = tl.program_id(0)
    m = tl.program_id(1)
    z = tl.program_id(2)
    # create index ranges
    hm = h * tl.num_programs(1) + m
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
    # extract information from LUT
    header = LUT + (hm // BLOCK_SIZE) * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # row-col offset
    off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
    off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
    mask = block_n < size
    # pointers
    As = Out + z * stride_zout + off_mn
    DOuts = DOut + z * stride_zdout + off_mn
    # do not need to read column indices in the dense case
    if IS_DENSE:
        ns = tl.arange(0, ROW_SIZE)
    else:
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
        start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
        ns = start_n * BLOCK_SIZE + lane_n
    # load data
    a = tl.load(As + lane_n, mask=mask, other=0.0)
    a = a.to(tl.float32)
    dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
    dout = dout.to(tl.float32)
    # compute gradient of softmax
    a = tl.where((ns > m) & is_causal & (a == a), 0., a)
    da = a * (dout - tl.sum(a * dout, 0))
    # apply relative attention
    if DR is not None:
        DR += z * stride_zr
        DR += h * stride_hr
        off_lo = (extent - m - 1) + ns
        mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
        tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
    da = da * scale
    # write-back
    DAs = DA + z * stride_zdx + off_mn
    tl.store(DAs + lane_n, da, mask=mask)


class _softmax(torch.autograd.Function):

    @staticmethod
    def make_lut(layout, block, device):
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        sizes = _empty.clone()
        # sizes along rows
        for h in range(layout.shape[0]):
            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
        total_sizes = sizes * block
        # offsets in block format
        offsets = torch.zeros_like(sizes)
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
        # block indices
        columns = layout.nonzero(as_tuple=False)[:, 2]
        header = torch.stack((sizes, offsets), dim=1).view(-1)
        lut = torch.cat((header, columns)).type(torch.int32).to(device)
        return lut, int(total_sizes.max())

    @staticmethod
    def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
        if scale is not None and isinstance(scale, torch.Tensor):
            assert scale.device.type == "cpu"
            scale = scale.item()
        M = a.shape[0]
        grid = [spdims[0], spdims[1] * block, M]
        rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
        rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
        # enqueue kernel
        out = torch.empty_like(a)
        _blocksparse_softmax_fwd[grid](
            out, a, a.stride(0), lut,  #
            rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn
            scale,  #
            is_causal,  #
            BLOCK_SIZE=block,  #
            ROW_SIZE=next_power_of_2(maxlut),  #
            IS_DENSE=is_dense,  #
            num_warps=num_warps(maxlut)  #
        )
        # save to context
        ctx.save_for_backward(out, lut)
        ctx.spdims = spdims
        ctx.block = block
        ctx.maxlut = maxlut
        ctx.scale = scale
        ctx.rel_shape = rel_shape
        ctx.rel_strides = rel_strides
        ctx.rel_dtype = a.dtype
        ctx.is_dense = is_dense
        ctx.is_causal = is_causal
        return out

    @staticmethod
    def backward(ctx, dout):
        # retrieve from context
        out, lut = ctx.saved_tensors
        # relative logits gradients
        dr = None
        if ctx.needs_input_grad[3]:
            dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
        # run kernel
        M = out.shape[0]
        grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
        da = torch.empty_like(dout)
        _blocksparse_softmax_bwd[grid](
            da, da.stride(0),  #
            dout, dout.stride(0),  #
            out, out.stride(0),  #
            ctx.scale,  #
            lut,  #
            dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],  #
            ctx.is_causal,  #
            BLOCK_SIZE=ctx.block,  #
            ROW_SIZE=next_power_of_2(ctx.maxlut),  #
            IS_DENSE=ctx.is_dense,  #
            num_warps=num_warps(ctx.maxlut)  #
        )
        return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)


class softmax:

    def __init__(self, layout, block, device, is_dense=False):
        self.spdims = layout.shape
        self.layout = layout
        self.block = block
        self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
        self.is_dense = is_dense

    def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
        if rel_logits is not None and rel_logits.dtype != a.dtype:
            raise ValueError(f"relative position embedding must be {a.dtype}")
        a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
                           self.is_dense)
        return a
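

# Minimal usage sketch (not part of the library proper). It assumes a CUDA device is
# available and that the input tensor holds block-sparse data of shape
# (Z, layout.sum(), block, block), i.e. one (block x block) tile per nonzero entry of
# the layout, as produced e.g. by the companion block-sparse matmul in "sdd" mode.
# The layout, sizes, dtype, and scale below are illustrative choices only.
if __name__ == "__main__":
    block = 16
    H, M, N = 2, 64, 64  # heads, rows, columns (in elements)
    # random per-head sparsity pattern over (block-row, block-column) pairs
    layout = torch.randint(0, 2, (H, M // block, N // block), dtype=torch.int64)
    layout[:, :, 0] = 1  # keep every block-row non-empty
    device = torch.device("cuda")
    sparse_softmax = softmax(layout, block, device)
    # batch of 4 block-sparse inputs, one tile per nonzero layout entry
    Z = 4
    a = torch.randn((Z, int(layout.sum()), block, block), dtype=torch.float16,
                    device=device, requires_grad=True)
    out = sparse_softmax(a, scale=0.125, is_causal=True)  # scale is illustrative, e.g. 1/sqrt(head_dim)
    out.sum().backward()  # also exercises _blocksparse_softmax_bwd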