nki_samples.reference.allocated_fused_linear.allocated_fused_rms_norm_qkv
- nki_samples.reference.allocated_fused_linear.allocated_fused_rms_norm_qkv = <neuronxcc.nki.compile.GenericKernel object>
Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to handle only fp16/bf16 tensor types. Internally, the RMS normalization is computed in fp32 (see norm_dtype) to avoid NaN errors.
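As a rough numerical reference for what the kernel computes, the following NumPy sketch evaluates RMSNorm(hidden) @ wQKV with the normalization done in fp32 and the result cast back to the input dtype. It is not the NKI implementation. It assumes `hidden` has shape (B, S, H), `weights` has shape (H, N) with gamma already folded in, and it uses fp16 because NumPy has no native bf16 dtype. The function name `rms_norm_qkv_reference` is hypothetical.

```python
import numpy as np


def rms_norm_qkv_reference(hidden, weights, eps=1e-6):
    """Hypothetical NumPy reference for RMSNorm(hidden) @ wQKV.

    hidden:  (B, S, H) array in BSH layout, fp16 (assumed).
    weights: (H, N) fused QKV weights with gamma already folded in (assumed layout).
    """
    # Upcast to fp32: in fp16 the square of a large activation
    # (|x| >= 256) exceeds the fp16 maximum of 65504 and becomes inf.
    x = hidden.astype(np.float32)
    # Root-mean-square over the hidden dimension H, with eps for stability.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    normed = x / rms
    # Project onto the fused Q/K/V columns in fp32, then cast back
    # to the input dtype, mirroring an fp16/bf16-in, fp16/bf16-out kernel.
    out = normed @ weights.astype(np.float32)
    return out.astype(hidden.dtype)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h = rng.standard_normal((2, 4, 8)).astype(np.float16)
    w = rng.standard_normal((8, 12)).astype(np.float16)
    print(rms_norm_qkv_reference(h, w).shape)
```

A reference like this is mainly useful for checking kernel output with a tolerance suited to fp16/bf16, rather than for exact comparison.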
- Parameters:
hidden (_type_) – Input tensor of the attention block in BSH (batch, sequence, hidden) layout
weights (_type_) – Fused QKV linear weights, assumed to have already been element-wise multiplied by the RMS norm weight vector (gamma)
out_tensor (_type_) – Output tensor
norm_dtype (_type_, optional) – Data type used for the RMS norm computation; should be fp32 to avoid NaN. Defaults to nl.float32.
eps (_type_, optional) – RMS norm epsilon term. Defaults to 1e-6.
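Because the kernel expects gamma to be folded into `weights`, that scaling has to happen before the call. Folding works because gamma scales each hidden channel, so (x̂ · gamma) @ W equals x̂ @ (diag(gamma) @ W), where x̂ is the RMS-normalized input. The NumPy sketch below shows this preprocessing, assuming gamma has shape (H,) and the unfused-gamma weights have shape (H, N). The helper name `fold_gamma_into_weights` is hypothetical and not part of nki_samples.

```python
import numpy as np


def fold_gamma_into_weights(gamma, w_qkv):
    """Hypothetical helper: return diag(gamma) @ w_qkv.

    gamma: (H,) RMS norm weight vector.
    w_qkv: (H, N) fused QKV weights (assumed layout).
    """
    # Broadcasting gamma[:, None] scales row i of w_qkv by gamma[i],
    # which equals diag(gamma) @ w_qkv without building the diagonal matrix.
    return gamma[:, None] * w_qkv


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    H, N = 8, 6
    gamma = rng.standard_normal(H).astype(np.float32)
    w = rng.standard_normal((H, N)).astype(np.float32)
    x_hat = rng.standard_normal((3, H)).astype(np.float32)  # already RMS-normalized rows
    # Applying gamma after normalization matches using the folded weights.
    print(np.allclose((x_hat * gamma) @ w, x_hat @ fold_gamma_into_weights(gamma, w), atol=1e-5))
```

Folding is done once, offline, so the kernel skips a per-token element-wise multiply.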