nki_samples.reference.attention.flash_attn_bwd
- nki_samples.reference.attention.flash_attn_bwd = <neuronxcc.nki.compile.GenericKernel object>
Flash attention backward kernel. Computes the backward gradients with respect to Q, K, and V.
- IO tensor layouts:
q_ref: shape (bs, nheads, head_size, seq)
k_ref: shape (bs, nheads, head_size, seq)
v_ref: shape (bs, nheads, head_size, seq)
o_ref: shape (bs, nheads, head_size, seq)
dy_ref: shape (bs, nheads, head_size, seq)
lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)
seed_ref: shape (1,)
logit_bias_ref: shape (bs, nheads, seq_q, seq_k)
out_dq_ref: shape (bs, nheads, head_size, seq)
out_dk_ref: shape (bs, nheads, head_size, seq)
out_dv_ref: shape (bs, nheads, head_size, seq)
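The layouts above can be made concrete with a small NumPy sketch. The sizes below are hypothetical placeholders chosen for illustration; `pmax` stands in for `nl.tile_size.pmax` (the maximum partition dimension, 128 on current Neuron hardware).

```python
import numpy as np

# Hypothetical sizes for illustration only.
bs, nheads, head_size, seq, pmax = 1, 2, 64, 512, 128

q = np.zeros((bs, nheads, head_size, seq), dtype=np.float32)        # q_ref
k = np.zeros_like(q)                                                # k_ref
v = np.zeros_like(q)                                                # v_ref
o = np.zeros_like(q)                                                # o_ref (forward output)
dy = np.zeros_like(q)                                               # dy_ref (output gradient)
lse = np.zeros((bs, nheads, pmax, seq // pmax), dtype=np.float32)   # lse_ref
seed = np.zeros((1,), dtype=np.int32)                               # seed_ref
logit_bias = np.zeros((bs, nheads, seq, seq), dtype=np.float32)     # logit_bias_ref
# The three outputs dq, dk, dv share the (bs, nheads, head_size, seq) layout of q.
```

Note that, unlike the common (bs, nheads, seq, head_size) convention, these kernels keep head_size before seq.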
- Detailed steps:
  1. D = rowsum(dO ◦ O) (pointwise multiply)
  2. Recompute softmax(Q^T@K + logit_bias)
     2.1 Q^T@K
     2.2 Scale the QK score
     2.3 Apply the causal mask and add logit_bias
     2.4 softmax
  3. Compute the gradients through the recomputed scores
     3.1 Compute the gradients of y = score @ V with respect to the loss
     3.2 Compute the gradients of y = softmax(x)
  4. Compute the gradients of Q^T@K
     4.1 Compute dQ
     4.2 Compute dK
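The steps above can be sketched for a single (batch, head) slice in NumPy. This is a math-only sketch, not the NKI kernel: the real kernel tiles the sequence, uses the saved lse_ref for the recompute, and fuses these steps on-chip, whereas the sketch below recomputes the softmax densely. The function name and the dense recompute are illustrative assumptions.

```python
import numpy as np

def flash_attn_bwd_ref(q, k, v, dy, scale=1.0, logit_bias=None):
    """q, k, v, dy: (head_size, seq) slices, matching the kernel's
    head_size-major layout. Returns (dq, dk, dv)."""
    d, n = q.shape
    # Step 2: recompute the attention scores.
    x = scale * (q.T @ k)                      # 2.1 + 2.2: scaled Q^T@K, (seq, seq)
    if logit_bias is not None:
        x = x + logit_bias                     # 2.3: add logit bias
    causal = np.tril(np.ones((n, n), dtype=bool))
    x = np.where(causal, x, -np.inf)           # 2.3: causal mask
    x = x - x.max(axis=-1, keepdims=True)      # numerical stability
    p = np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)  # 2.4: softmax

    o = v @ p.T                                # recomputed forward output, (d, seq)
    # Step 1: D = rowsum(dO ◦ O), one scalar per query position.
    D = np.sum(dy * o, axis=0)                 # (seq,)

    dv = dy @ p                                # 3.1: grad of y = score @ V
    dp = dy.T @ v                              # grad w.r.t. the softmax output
    dx = p * (dp - D[:, None])                 # 3.2: softmax backward, using D
    dq = scale * (k @ dx.T)                    # 4.1: grad through Q^T@K
    dk = scale * (q @ dx)                      # 4.2
    return dq, dk, dv
```

The D trick in step 1 is what lets the softmax backward run row by row: the usual Jacobian term sum_j dP[i,j]·P[i,j] collapses to dO[:,i]·O[:,i], so it can be precomputed from tensors the kernel already has.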