nki_samples.reference.attention.flash_fwd
- nki_samples.reference.attention.flash_fwd = <neuronxcc.nki.compile.GenericKernel object>
Flash Attention Forward kernel
- IO tensor layouts:
q: shape (bs, n_heads, d, seq_q)
k: shape (bs, nk_heads, d, seq_k)
v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d)
seed: shape (1,)
logit_bias: shape (bs, n_heads, seq_q, seq_k)
o: shape (bs, n_heads, seq_q, d)
lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None
This kernel requires seq_k == seq_v
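As a concrete illustration of these layouts, a minimal allocation sketch for the non-transposed-v case (the sizes and the float16 dtype below are hypothetical, chosen only for the example):

    import numpy as np

    # Hypothetical sizes for illustration only.
    bs, n_heads, d, seq = 1, 8, 128, 4096

    # q and k use a (bs, heads, d, seq) layout: head dimension before sequence.
    q = np.zeros((bs, n_heads, d, seq), dtype=np.float16)
    k = np.zeros((bs, n_heads, d, seq), dtype=np.float16)
    # With config.should_transpose_v=False, v uses (bs, heads, seq, d).
    v = np.zeros((bs, n_heads, seq, d), dtype=np.float16)
    # logit_bias covers all query/key positions: (bs, n_heads, seq_q, seq_k).
    logit_bias = np.zeros((bs, n_heads, seq, seq), dtype=np.float16)
    # Output layout is (bs, n_heads, seq_q, d), unlike q and k.
    o = np.zeros((bs, n_heads, seq, d), dtype=np.float16)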
- IO tensor dtypes:
This kernel assumes all IO tensors have the same dtype
If mixed_precision is True, all Tensor Engine operations will be performed in bfloat16 and accumulation will be performed in float32. Otherwise, the intermediates will be of the same type as the inputs.
- Compile-time Constants:
softmax_scale: scaling factor for softmax; if None, defaults to 1.0/(d**0.5)
mixed_precision: flag to run non-matmul ops in fp32 precision; default is True. If False, the same precision as the input types is used.
causal_mask: flag to enable causal masking
config: instance of nki.kernels.attention.FlashConfig carrying the performance config parameters for flash attention, with default values:
    seq_tile_size: default=2048, size of the kv tile used in the attention computation reduction
    training: bool to indicate training vs. inference, default=True
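A hedged construction sketch for this config. It assumes FlashConfig is importable from the same module as flash_fwd and accepts the fields documented here as keyword arguments; check the class referenced above for the exact import path and signature.

    # Assumed import path; the docstring refers to the class as
    # nki.kernels.attention.FlashConfig, so the path may differ per release.
    from nki_samples.reference.attention import FlashConfig

    # seq_tile_size, training, and attn_core_tile_size defaults are documented
    # in this section and in the Performance Notes below; should_transpose_v
    # matches the v layout described under IO tensor layouts.
    config = FlashConfig(
        seq_tile_size=2048,        # kv tile size for the attention reduction
        training=True,             # training mode: lse output is produced
        should_transpose_v=False,  # expect v as (bs, nv_heads, seq_v, d)
        attn_core_tile_size=256,   # o_buffer tiling factor (see Performance Notes)
    )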
- Performance Notes:
For better performance, the kernel is tiled to be of size config.seq_tile_size, and flash attention math techniques are applied in units of config.seq_tile_size. A seqlen that is not divisible by config.seq_tile_size is not supported at the moment.
For large seqlen, o_buffer would overflow the statebuf, so the kernel tiles o_buffer based on the value of config.attn_core_tile_size. This is a tradeoff between memory usage and performance. The default value of config.attn_core_tile_size is 256, which means o_buffer will take roughly half of the statebuf. The computes are tiled accordingly, and the DMA will be rematerialized seqlen_q // B_P_SIZE // attn_core_tile_size times.
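As a worked example of that count (B_P_SIZE is taken as 128 here, the partition tile size assumed by this sketch; the seqlen_q value is hypothetical):

    seqlen_q = 131072            # hypothetical query sequence length
    B_P_SIZE = 128               # partition tile size assumed by this sketch
    attn_core_tile_size = 256    # documented default of config.attn_core_tile_size

    remat_count = seqlen_q // B_P_SIZE // attn_core_tile_size
    print(remat_count)           # 131072 // 128 // 256 = 4 DMA rematerializations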
- GQA support Notes:
The SPMD grid used to launch the kernel should be over kv_heads instead of n_heads.
- Example usage:
- MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
usage:
flash_fwd[b, h](q, k, v, ...)
- GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
usage:
flash_fwd[b, kv_h](q, k, v, ...)
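A minimal end-to-end sketch of the GQA case under the layouts above (sizes and dtype are hypothetical; the remaining kernel arguments stay elided with ... exactly as in the usage lines):

    import numpy as np

    # Hypothetical GQA sizes: 32 query heads sharing 8 kv heads.
    b, h, kv_h, d, s = 1, 32, 8, 128, 4096
    assert h % kv_h == 0  # query heads must map evenly onto kv heads

    q = np.zeros((b, h, d, s), dtype=np.float16)
    k = np.zeros((b, kv_h, d, s), dtype=np.float16)
    v = np.zeros((b, kv_h, s, d), dtype=np.float16)

    # Per the GQA support note, the SPMD grid is over kv heads, not query heads:
    # flash_fwd[b, kv_h](q, k, v, ...)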