nki_samples.reference.attention.flash_fwd

nki_samples.reference.attention.flash_fwd = <neuronxcc.nki.compile.GenericKernel object>

Flash Attention Forward kernel

IO tensor layouts:
  • q: shape (bs, n_heads, d, seq_q)

  • k: shape (bs, nk_heads, d, seq_k)

  • v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d)

  • seed: shape (1,)

  • logit_bias: shape (bs, n_heads, seq_q, seq_k)

  • o: shape (bs, n_heads, seq_q, d)

  • lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None

  • This kernel requires seq_k == seq_v (a minimal shape check is sketched after this list)
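
The following is a minimal, illustrative shape check for the layouts above; the function name and the should_transpose_v argument are hypothetical, while the shape conventions are taken directly from the bullets.

  # Minimal shape check mirroring the IO tensor layout bullets (illustrative only).
  def check_layouts(q, k, v, should_transpose_v=False):
      bs, n_heads, d, seq_q = q.shape        # q: (bs, n_heads, d, seq_q)
      _, nk_heads, d_k, seq_k = k.shape      # k: (bs, nk_heads, d, seq_k)
      # v: (bs, nv_heads, d, seq_v) if should_transpose_v else (bs, nv_heads, seq_v, d)
      seq_v = v.shape[3] if should_transpose_v else v.shape[2]
      assert d_k == d, "q and k must share the head dimension d"
      assert seq_k == seq_v, "this kernel requires seq_k == seq_v"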

IO tensor dtypes:
  • This kernel assumes all IO tensors have the same dtype

  • If mixed_precision is True, all Tensor Engine operations will be performed in bfloat16 and accumulation will be performed in float32. Otherwise, the intermediates will be kept in the same dtype as the inputs.

Compile-time Constants:
  • softmax_scale: scaling factor for the softmax; if None, defaults to 1.0/(d**0.5)

  • mixed_precision: flag to keep non-matmul ops in fp32 precision; default is True. If False, the same precision as the input dtypes is used.

  • causal_mask: flag to enable causal masking

  • config: instance of nki.kernels.attention.FlashConfig carrying performance config parameters for flash attention, with the following fields and defaults (see the invocation sketch after this list):

    seq_tile_size: size of the KV tile used in the attention computation reduction, default=2048

    training: bool to indicate training vs. inference, default=True
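
As a hedged invocation sketch: the tensor layouts follow the bullets above (with should_transpose_v left False), FlashConfig is imported from the path named above, and the compile-time constants are passed as keyword arguments. The remaining tensor arguments (seed, logit_bias) and the exact return convention are elided assumptions rather than guarantees.

  import numpy as np
  from nki_samples.reference.attention import flash_fwd
  from nki.kernels.attention import FlashConfig  # import path as documented above (assumption)

  bs, n_heads, d, seq = 1, 4, 128, 4096          # hypothetical sizes
  q = np.random.randn(bs, n_heads, d, seq).astype(np.float32)   # (bs, n_heads, d, seq_q)
  k = np.random.randn(bs, n_heads, d, seq).astype(np.float32)   # (bs, nk_heads, d, seq_k)
  v = np.random.randn(bs, n_heads, seq, d).astype(np.float32)   # (bs, nv_heads, seq_v, d)

  cfg = FlashConfig(seq_tile_size=2048, training=True)          # fields listed above

  out = flash_fwd[bs, n_heads](                                 # SPMD grid as in the usage lines below
      q, k, v,
      softmax_scale=None,      # None -> 1.0 / (d ** 0.5)
      mixed_precision=True,
      causal_mask=True,
      config=cfg,
  )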

Performance Notes:

For better performance, the kernel is tiled to size config.seq_tile_size, and the flash attention math techniques are applied in units of config.seq_tile_size. Sequence lengths that are not divisible by config.seq_tile_size are not supported at the moment.

For large seqlen, o_buffer would overflow the state buffer (SBUF), so the kernel tiles o_buffer based on the value of config.attn_core_tile_size. This is a tradeoff between memory usage and performance. The default value of config.attn_core_tile_size is 256, which means o_buffer takes roughly half of the state buffer. The compute is tiled accordingly, and the DMA is rematerialized seqlen_q // B_P_SIZE // attn_core_tile_size times.
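
A small worked example of the rematerialization count (all numbers are assumptions; B_P_SIZE is taken to be the 128-partition tile size, nl.tile_size.pmax):

  seqlen_q = 131072          # assumed query sequence length
  B_P_SIZE = 128             # assumed partition tile size (nl.tile_size.pmax)
  attn_core_tile_size = 256  # default config.attn_core_tile_size

  remat_count = seqlen_q // B_P_SIZE // attn_core_tile_size  # 1024 // 256 = 4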

GQA Support Notes:

The SPMD launch grid should iterate over kv_heads instead of n_heads.

Example usage:
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]

usage: flash_fwd[b, h](q, k, v, ...)

GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]

usage: flash_fwd[b, kv_h](q, k, v, ...)
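
A hedged GQA sketch with hypothetical sizes: only the launch grid and the tensor shapes come from the lines above; the trailing "..." stands for the arguments that the usage line leaves elided.

  import numpy as np
  from nki_samples.reference.attention import flash_fwd

  b, h, kv_h, d, s = 1, 32, 8, 128, 4096                  # hypothetical: 32 query heads share 8 KV heads
  q = np.random.randn(b, h, d, s).astype(np.float32)      # [b, h, d, s]
  k = np.random.randn(b, kv_h, d, s).astype(np.float32)   # [b, kv_h, d, s]
  v = np.random.randn(b, kv_h, s, d).astype(np.float32)   # [b, kv_h, s, d]

  # Launch grid iterates over kv_heads rather than n_heads (see the GQA note above).
  out = flash_fwd[b, kv_h](q, k, v, ...)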