nki_samples.reference.attention.flash_attn_bwd

nki_samples.reference.attention.flash_attn_bwd = <neuronxcc.nki.compile.GenericKernel object>

Flash attention backward kernel. Computes the backward gradients of the attention output with respect to Q, K, and V.

IO tensor layouts:
  • q_ref: shape (bs, nheads, head_size, seq)

  • k_ref: shape (bs, nheads, head_size, seq)

  • v_ref: shape (bs, nheads, head_size, seq)

  • o_ref: shape (bs, nheads, head_size, seq)

  • dy_ref: shape (bs, nheads, head_size, seq)

  • lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)

  • seed_ref: shape (1,)

  • logit_bias_ref: shape (bs, nheads, seq_q, seq_k)

  • out_dq_ref: shape (bs, nheads, head_size, seq)

  • out_dk_ref: shape (bs, nheads, head_size, seq)

  • out_dv_ref: shape (bs, nheads, head_size, seq)
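
The sketch below is only an illustration of how these layouts fit together, using NumPy arrays as stand-ins for the kernel's IO tensors. The concrete sizes (bs, nheads, head_size, seq), the seed dtype, and the value of nl.tile_size.pmax (assumed to be 128 here) are placeholder assumptions, not requirements of the kernel.

```python
import numpy as np

# Illustrative sizes only; real values depend on the model and sequence length.
bs, nheads, head_size, seq = 1, 2, 128, 4096
pmax = 128  # assumed value of nl.tile_size.pmax (partition dimension size)

dtype = np.float32
q = np.random.rand(bs, nheads, head_size, seq).astype(dtype)       # q_ref
k = np.random.rand(bs, nheads, head_size, seq).astype(dtype)       # k_ref
v = np.random.rand(bs, nheads, head_size, seq).astype(dtype)       # v_ref
o = np.random.rand(bs, nheads, head_size, seq).astype(dtype)       # o_ref (forward output)
dy = np.random.rand(bs, nheads, head_size, seq).astype(dtype)      # dy_ref (gradient of loss w.r.t. o)
lse = np.random.rand(bs, nheads, pmax, seq // pmax).astype(dtype)  # lse_ref (log-sum-exp from forward)
seed = np.zeros((1,), dtype=np.int32)                              # seed_ref (dropout seed; dtype assumed)
logit_bias = np.zeros((bs, nheads, seq, seq), dtype=dtype)         # logit_bias_ref (seq_q == seq_k == seq here)

# Output buffers the kernel writes into
dq = np.zeros_like(q)  # out_dq_ref
dk = np.zeros_like(k)  # out_dk_ref
dv = np.zeros_like(v)  # out_dv_ref
```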

Detailed steps:
  1. D = rowsum(dO ◦ O) (pointwise multiply)

  2. Recompute softmax(Q^T@K + logit_bias)

     2.1 Q^T@K
     2.2 Scale the QK score
     2.3 Apply causal mask and add logit_bias
     2.4 softmax

  3. Compute the gradients of y = score @ V with respect to the loss

  4. Compute the gradients of y = softmax(x)

  5. Compute the gradients of Q^T@K

     5.1 Compute dQ
     5.2 Compute dK
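
The following is a minimal NumPy sketch of the math in these steps for a single (batch, head) slice, not the kernel itself: it omits dropout, the causal mask, tiling, and the precomputed lse, and it operates on a (seq, head_size) layout rather than the kernel's transposed (head_size, seq) layout. The helper name attention_bwd_reference is hypothetical.

```python
import numpy as np

def attention_bwd_reference(q, k, v, dy, scale=1.0, logit_bias=None):
    """Plain NumPy reference for one (batch, head) slice.

    q, k, v, dy: arrays of shape (seq, head_size). The kernel stores these
    transposed as (head_size, seq), so transpose before/after when comparing.
    """
    # 2. Recompute softmax(scale * Q @ K^T + logit_bias)
    s = scale * (q @ k.T)                        # 2.1 / 2.2  raw, scaled scores
    if logit_bias is not None:
        s = s + logit_bias                       # 2.3  add logit_bias (causal mask omitted)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)           # 2.4  softmax

    o = p @ v                                    # forward output (o_ref)
    d = np.sum(dy * o, axis=-1, keepdims=True)   # 1. D = rowsum(dO ∘ O)

    dv = p.T @ dy                                # 3. gradient of y = score @ V w.r.t. V
    dp = dy @ v.T                                #    gradient w.r.t. the softmax output
    ds = p * (dp - d)                            # 4. softmax backward
    dq = scale * (ds @ k)                        # 5.1 dQ
    dk = scale * (ds.T @ q)                      # 5.2 dK
    return dq, dk, dv
```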