nki_samples.reference.allocated_fused_linear.allocated_fused_rms_norm_qkv

nki_samples.reference.allocated_fused_linear.allocated_fused_rms_norm_qkv = <neuronxcc.nki.compile.GenericKernel object>

Allocated kernel that computes RMSNorm(hidden) @ wQKV, where wQKV is the fused QKV weight matrix passed as `weights`. The kernel handles only fp16/bf16 tensor types. Internally, the normalization is computed in fp32 to avoid NaN errors.
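As a rough illustration of what the kernel computes, here is a NumPy reference sketch. It is not the NKI kernel. `rms_norm_qkv_reference` is a hypothetical name, and the sketch makes three assumptions: it uses the standard RMSNorm formula x / sqrt(mean(x²) + eps), it uses float16 because NumPy has no bfloat16, and it performs the matmul in fp32 before casting back, which is a guess about the kernel's precision.

```python
import numpy as np


def rms_norm_qkv_reference(hidden, weights, eps=1e-6, norm_dtype=np.float32):
    """Hypothetical NumPy sketch of RMSNorm(hidden) @ weights.

    hidden:  [batch, seqlen, hidden_size] (BSH) in float16
    weights: [hidden_size, fused_qkv_size], gamma assumed already folded in
    """
    # Upcast before squaring: in float16, values above ~256 overflow to inf
    # when squared, which corrupts the mean.
    x = hidden.astype(norm_dtype)
    # Mean of squares over the hidden (last) dimension.
    mean_sq = np.mean(x * x, axis=-1, keepdims=True)
    # Normalize; gamma is not applied here because it lives in `weights`.
    normed = x / np.sqrt(mean_sq + eps)
    # Assumption: matmul in the normalization dtype, then cast back.
    out = normed @ weights.astype(norm_dtype)
    return out.astype(hidden.dtype)
```

The upcast is the key design choice. Squaring a half-precision activation is the step most likely to overflow, so it is done in `norm_dtype` while the inputs and outputs stay in half precision.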

Parameters:
  • hidden (_type_) – Input tensor of the attention block in BSH layout (batch, sequence length, hidden size)

  • weights (_type_) – Fused QKV linear weights, assumed to be already element-wise multiplied with the RMS norm weight vector (gamma)

  • out_tensor (_type_) – Output tensor

  • norm_dtype (_type_, optional) – Data type used for the RMS norm computation. Should be float32 to avoid NaN. Defaults to nl.float32.

  • eps (_type_, optional) – RMS norm epsilon term. Defaults to 1e-6.
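Because `weights` is expected to already include gamma, the caller folds the scale in before invoking the kernel. The following is a minimal NumPy sketch of that preparation step. It assumes a [hidden_size, fused_qkv_size] weight layout, consistent with RMSNorm(hidden) @ wQKV. `rms_normalize` and `fold_gamma_into_weights` are hypothetical helpers, not part of nki_samples. Folding works because multiplying column i of the normalized activations by gamma[i] is the same as multiplying row i of the weight matrix by gamma[i].

```python
import numpy as np


def rms_normalize(x, eps=1e-6):
    # RMS normalization without the learned scale (gamma), computed in fp32.
    x32 = x.astype(np.float32)
    return x32 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)


def fold_gamma_into_weights(gamma, w_qkv):
    # Hypothetical helper: scale row i of the [hidden_size, fused_qkv_size]
    # matrix by gamma[i]. This is the element-wise product `weights` is
    # assumed to carry.
    return gamma[:, None] * w_qkv
```

With the weights prepared this way, (rms_normalize(x) * gamma) @ w_qkv and rms_normalize(x) @ fold_gamma_into_weights(gamma, w_qkv) agree up to floating-point rounding. This lets the kernel skip the per-element gamma multiply.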