Generate NKI Kernels from NKIPy Kernels

2. Generate NKI Kernels from NKIPy Kernels#

NKIPy is a project under active development to help Neuron users write and execute kernels on Trainium with ease.

This demo focuses on NKIPy's ability to take in NumPy kernels and generate NKI code.

Let’s first look at a few examples!

2.1. Softmax Examples#

import numpy as np

# Here we have a softmax kernel implemented in NumPy
def softmax_kernel(x):
    """Numerically stable softmax over the last axis, written in plain NumPy."""
    # Subtract the row-wise max first so np.exp never overflows.
    row_max = np.max(x, axis=-1, keepdims=True)
    numerator = np.exp(np.subtract(x, row_max))
    denominator = np.sum(numerator, axis=-1, keepdims=True)
    return np.divide(numerator, denominator)

Since NKIPy kernels can be just NumPy kernels, they can run as such.

# Build a small random fp32 input and run the kernel eagerly as ordinary NumPy.
x = np.random.rand(2, 2).astype(np.float32)
print("Input:", x)
out = softmax_kernel(x)
print("Output:", out)
Input: [[0.9579238  0.1510723 ]
 [0.64579445 0.8481816 ]]
Output: [[0.69143814 0.3085618 ]
 [0.44957522 0.55042475]]

To generate NKI code from the NumPy function above, we need to first trace it — in this step, NKIPy will go through the NumPy kernel and convert it into a NKIPy kernel.

To trace it, we need to wrap it with trace, then specialize it with a concrete shape.

from nkipy.core.trace import NKIPyKernel

softmax_nkipy_kernel = NKIPyKernel.trace(softmax_kernel)

Now that the function is traced, it has become a NKIPy kernel, and we are ready to convert it to NKI.

from nkipy.core.compile import lower_to_nki

# Specialize the traced kernel with the concrete (2, 2) input, then lower it to NKI source.
softmax_nkipy_kernel.specialize(x)
nki_code = lower_to_nki(softmax_nkipy_kernel)
# Add some helper function to display the generated code
from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import HtmlFormatter
from IPython.display import HTML, display

def display_code(code):
    """Render *code* in the notebook as Pygments-highlighted HTML."""
    # full=True makes Pygments emit a self-contained document with inline CSS.
    html_formatter = HtmlFormatter(style='friendly', full=True)
    highlighted_code = highlight(code, PythonLexer(), html_formatter)
    custom_html = f"""
    <div>
        {highlighted_code}
    </div>
    """
    display(HTML(custom_html))


display_code(nki_code)

import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt
import neuronxcc.nki.isa as nisa
from neuronxcc.nki import trace
from neuronxcc.nki.language import par_dim

@trace
def sg0000(
  v1,
  v2,
):
  import numpy as np
  import neuronxcc.nki as nki
  import neuronxcc.nki.language as nl
  import neuronxcc.nki.typing as nt
  import neuronxcc.nki.isa as nisa
  from neuronxcc.nki import trace
  from neuronxcc.nki.language import par_dim
 
  v1 = v1
  v2 = v2
 
  v3 = nl.ndarray((nl.par_dim(2), 2), dtype=np.float32, name="op1.103", buffer=nl.sbuf)
  v4 = nl.ndarray((nl.par_dim(2), 2), dtype=np.float32, name="op5.105", buffer=nl.sbuf)
  v5 = nl.ndarray((nl.par_dim(2), 1), dtype=np.float32, name="op1.107", buffer=nl.sbuf)
  v6 = nl.ndarray((nl.par_dim(2), 1), dtype=np.float32, name="op5.109", buffer=nl.sbuf)
  v7 = nl.ndarray((nl.par_dim(2), 2), dtype=np.float32, name="op6.111", buffer=nl.sbuf)
  v8 = nl.ndarray((nl.par_dim(2), 1), dtype=np.float32, name="op8.113", buffer=nl.sbuf)
  v9 = nl.ndarray((nl.par_dim(2), 1), dtype=np.float32, name="op12.115", buffer=nl.sbuf)
  v10 = nl.ndarray((nl.par_dim(2), 2), dtype=np.float32, name="op12.117", buffer=nl.sbuf)
 
  def BB_entry_1():
    v3[nl.arange(2)[:, None], nl.arange(2)[None, :]] = nl.load(v1[nl.arange(2)[:, None], nl.arange(2)[None, :]], dtype=np.float32, mask=None)
    v4[nl.arange(2)[:, None], nl.arange(2)[None, :]] = nl.load(v1[nl.arange(2)[:, None], nl.arange(2)[None, :]], dtype=np.float32, mask=None)
    v5[nl.arange(2)[:, None], 0] = nisa.tensor_reduce(nl.maximum, data=v3[nl.arange(2)[:, None], nl.arange(2)[None, :]], mask=None, axis=[1], dtype=np.float32, negate=False)
    v6[nl.arange(2)[:, None], 0] = nisa.tensor_scalar(data=v5[nl.arange(2)[:, None], 0],  op0=nl.maximum, operand0=-np.inf, reverse0=False, op1=nl.multiply, operand1=np.dtype(np.float32).type(-1.0), reverse1=False, dtype=np.float32, mask=None, engine=nki.isa.engine.unknown)
    v7[nl.arange(2)[:, None], nl.arange(2)[None, :]] = nisa.activation(op=nl.exp, data=v4[nl.arange(2)[:, None], nl.arange(2)[None, :]], bias=v6[nl.arange(2)[:, None], 0], scale=1.0, mask=None, dtype=np.float32)
    v8[nl.arange(2)[:, None], 0] = nisa.tensor_reduce(nl.add, data=v7[nl.arange(2)[:, None], nl.arange(2)[None, :]], mask=None, axis=[1], dtype=np.float32, negate=False)
    v9[nl.arange(2)[:, None], 0] = nisa.reciprocal(data=v8[nl.arange(2)[:, None], 0], mask=None, dtype=np.float32)
    v10[nl.arange(2)[:, None], nl.arange(2)[None, :]] = nisa.tensor_scalar(data=v7[nl.arange(2)[:, None], nl.arange(2)[None, :]], op0=nl.multiply, operand0=v9[nl.arange(2)[:, None], 0], reverse0=False, dtype=np.float32, mask=None, engine=nki.isa.engine.unknown)
    nl.store(v2[nl.arange(2)[:, None], nl.arange(2)[None, :]], value=v10[nl.arange(2)[:, None], nl.arange(2)[None, :]], mask=None)
 
  BB_entry_1()


cu = sg0000.specialize(
  nt.tensor[(2, 2), np.float32], # i=0
  nt.tensor[(2, 2), np.float32], # i=1
)
print(cu)
ir = cu


# nki.simulate_kernel(sg0000, 
  # np.ndarray(shape=(2, 2), dtype=np.float32), # i=0
  # np.ndarray(shape=(2, 2), dtype=np.float32), # i=1
# )

The above NKI code is translated from the optimized intermediate representation of the Neuron Compiler.

There are still some gaps here from a proper hand-written NKI kernel:

  • We recently changed how NKI code returns output tensors. This generated NKI uses the old syntax, which takes the output tensor as an input argument.

  • Some additional structures are created, such as the additional BB_entry_1

  • All shapes are concrete values rather than variables

Let’s try to change the input tensor shape, and do the process again.

# Re-trace and specialize with a larger (256, 256) input to see how the
# generated NKI changes once tiling is required.
x = np.random.rand(256, 256).astype(np.float32)
softmax_nkipy_kernel = NKIPyKernel.trace(softmax_kernel)
_ = softmax_nkipy_kernel.specialize(x)
nki_code = lower_to_nki(softmax_nkipy_kernel)

display_code(nki_code)

import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt
import neuronxcc.nki.isa as nisa
from neuronxcc.nki import trace
from neuronxcc.nki.language import par_dim

@trace
def sg0000(
  v1,
  v2,
):
  import numpy as np
  import neuronxcc.nki as nki
  import neuronxcc.nki.language as nl
  import neuronxcc.nki.typing as nt
  import neuronxcc.nki.isa as nisa
  from neuronxcc.nki import trace
  from neuronxcc.nki.language import par_dim
 
  v1 = v1
  v2 = v2
 
  v3 = nl.ndarray((2, nl.par_dim(128), 256), dtype=np.float32, name="op1.129", buffer=nl.sbuf)
  v4 = nl.ndarray((2, nl.par_dim(128), 1), dtype=np.float32, name="op1.131", buffer=nl.sbuf)
  v5 = nl.ndarray((2, nl.par_dim(128), 1), dtype=np.float32, name="op5.133", buffer=nl.sbuf)
  v6 = nl.ndarray((2, nl.par_dim(128), 256), dtype=np.float32, name="op6.135", buffer=nl.sbuf)
  v7 = nl.ndarray((2, nl.par_dim(128), 1), dtype=np.float32, name="op8.137", buffer=nl.sbuf)
  v8 = nl.ndarray((2, nl.par_dim(128), 1), dtype=np.float32, name="op12.139", buffer=nl.sbuf)
  v9 = nl.ndarray((2, nl.par_dim(128), 256), dtype=np.float32, name="op12.141", buffer=nl.sbuf)
 
  def BB_entry_1():
    for i0 in nl.affine_range(2):
      v3[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]] = nl.load(v1[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]], dtype=np.float32, mask=None)
      v4[i0, nl.arange(128)[:, None], 0] = nisa.tensor_reduce(nl.maximum, data=v3[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]], mask=None, axis=[1], dtype=np.float32, negate=False)
      v5[i0, nl.arange(128)[:, None], 0] = nisa.tensor_scalar(data=v4[i0, nl.arange(128)[:, None], 0],  op0=nl.maximum, operand0=-np.inf, reverse0=False, op1=nl.multiply, operand1=np.dtype(np.float32).type(-1.0), reverse1=False, dtype=np.float32, mask=None, engine=nki.isa.engine.unknown)
      v6[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]] = nisa.activation(op=nl.exp, data=v3[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]], bias=v5[i0, nl.arange(128)[:, None], 0], scale=1.0, mask=None, dtype=np.float32)
      v7[i0, nl.arange(128)[:, None], 0] = nisa.tensor_reduce(nl.add, data=v6[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]], mask=None, axis=[1], dtype=np.float32, negate=False)
      v8[i0, nl.arange(128)[:, None], 0] = nisa.reciprocal(data=v7[i0, nl.arange(128)[:, None], 0], mask=None, dtype=np.float32)
      v9[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]] = nisa.tensor_scalar(data=v6[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]], op0=nl.multiply, operand0=v8[i0, nl.arange(128)[:, None], 0], reverse0=False, dtype=np.float32, mask=None, engine=nki.isa.engine.unknown)
      nl.store(v2[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]], value=v9[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]], mask=None)
      """ end loop i0 """
 
  BB_entry_1()


cu = sg0000.specialize(
  nt.tensor[(2, 128, 256), np.float32], # i=0
  nt.tensor[(2, 128, 256), np.float32], # i=1
)
print(cu)
ir = cu


# nki.simulate_kernel(sg0000, 
  # np.ndarray(shape=(2, 128, 256), dtype=np.float32), # i=0
  # np.ndarray(shape=(2, 128, 256), dtype=np.float32), # i=1
# )

The generated NKI code is certainly different! We see a new loop being introduced, for i0 in nl.affine_range(2): — this is because the compiler performs tiling so the data fits within the 128-partition dimension limit of the Trainium hardware.

2.2. Matrix Multiplication Examples#

Now let’s move on to something more interesting — matrix multiplication. Trainium hardware is exceptionally good at it. Let’s see how the Neuron Compiler fully utilizes the hardware!

This goes well with the NKI Matrix Multiplication Tutorial.

def matmul_kernel(x, y):
    """Matrix-multiply x by y via NumPy's matmul ufunc."""
    product = np.matmul(x, y)
    return product

def gen_matmul_nki(M, N, K):
    """Trace matmul_kernel at shapes [M, K] @ [K, N] and return the generated NKI source."""
    lhs = np.random.rand(M, K).astype(np.float32)
    rhs = np.random.rand(K, N).astype(np.float32)

    # Trace the NumPy kernel, pin it to concrete shapes, then lower to NKI.
    traced_kernel = NKIPyKernel.trace(matmul_kernel)
    _ = traced_kernel.specialize(lhs, rhs)
    return lower_to_nki(traced_kernel)

display_code(gen_matmul_nki(M=64, N=512, K=128)) # [64, 128] @ [128, 512]

import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt
import neuronxcc.nki.isa as nisa
from neuronxcc.nki import trace
from neuronxcc.nki.language import par_dim

@trace
def sg0000(
  v1,
  v2,
  v3,
):
  import numpy as np
  import neuronxcc.nki as nki
  import neuronxcc.nki.language as nl
  import neuronxcc.nki.typing as nt
  import neuronxcc.nki.isa as nisa
  from neuronxcc.nki import trace
  from neuronxcc.nki.language import par_dim
 
  v1 = v1
  v2 = v2
  v3 = v3
 
  v4 = nl.shared_constant(np.identity(128, dtype=np.float32))
  v5 = nl.ndarray((nl.par_dim(128), 128), dtype=np.float32, name="identity_local_75", buffer=nl.sbuf)
  v6 = nl.ndarray((nl.par_dim(64), 128), dtype=np.float32, name="40.83", buffer=nl.sbuf)
  v7 = nl.zeros((nl.par_dim(128), 64), dtype=np.float32, name="40.71", buffer=nl.psum, lazy_initialization=True)
  v8 = nl.ndarray((nl.par_dim(128), 64), dtype=np.float32, name="40.68", buffer=nl.sbuf)
  v9 = nl.ndarray((2, nl.par_dim(128), 256), dtype=np.float32, name="op0.85", buffer=nl.sbuf)
  v10 = nl.zeros((2, nl.par_dim(64), 256), dtype=np.float32, name="op0.81", buffer=nl.psum, lazy_initialization=True)
  v11 = nl.ndarray((2, nl.par_dim(64), 256), dtype=np.float32, name="", buffer=nl.sbuf)
 
  def BB_entry_1():
    v5[nl.arange(128)[:, None], nl.arange(128)[None, :]] = nl.load(v4[nl.arange(128)[:, None], nl.arange(128)[None, :]], dtype=np.float32, mask=None)
    v6[nl.arange(64)[:, None], nl.arange(128)[None, :]] = nl.load(v1[nl.arange(64)[:, None], nl.arange(128)[None, :]], dtype=np.float32, mask=None)
    v7[nl.arange(128)[:, None], nl.arange(64)[None, :]] = nisa.nc_matmul(v6[nl.arange(64)[:, None], nl.arange(128)[None, :]], v5[nl.arange(64)[:, None], nl.arange(64)[None, :]], is_stationary_onezero=False, is_moving_onezero=True, mask=None, is_transpose=True)
    v8[nl.arange(128)[:, None], nl.arange(64)[None, :]] = nl.copy(v7[nl.arange(128)[:, None], nl.arange(64)[None, :]], dtype=np.float32, mask=None)
   
    for i0 in nl.affine_range(2):
      v9[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]] = nl.load(v2[nl.arange(128)[:, None], i0, nl.arange(256)[None, :]], dtype=np.float32, mask=None)
      v10[i0, nl.arange(64)[:, None], nl.arange(256)[None, :]] = nisa.nc_matmul(v8[nl.arange(128)[:, None], nl.arange(64)[None, :]], v9[i0, nl.arange(128)[:, None], nl.arange(256)[None, :]], is_stationary_onezero=False, is_moving_onezero=False, mask=None)
      v11[i0, nl.arange(64)[:, None], nl.arange(256)[None, :]] = nl.copy(v10[i0, nl.arange(64)[:, None], nl.arange(256)[None, :]], dtype=np.float32, mask=None)
      """ end loop i0 """
   
    for i1 in nl.affine_range(2):
      nl.store(v3[nl.arange(64)[:, None], i1, nl.arange(256)[None, :]], value=v11[i1, nl.arange(64)[:, None], nl.arange(256)[None, :]], mask=None)
      """ end loop i1 """
 
  BB_entry_1()


cu = sg0000.specialize(
  nt.tensor[(64, 128), np.float32], # i=0
  nt.tensor[(128, 2, 256), np.float32], # i=1
  nt.tensor[(64, 2, 256), np.float32], # i=2
)
print(cu)
ir = cu


# nki.simulate_kernel(sg0000, 
  # np.ndarray(shape=(64, 128), dtype=np.float32), # i=0
  # np.ndarray(shape=(128, 2, 256), dtype=np.float32), # i=1
  # np.ndarray(shape=(64, 2, 256), dtype=np.float32), # i=2
# )

Compared to the Basic Compute Kernel in the tutorial, this generated kernel does an additional nc_matmul, which performs the transpose of x, since unlike the tutorial the lhs is not pre-transposed.

Let’s try with some larger sizes.

display_code(gen_matmul_nki(M=1024, N=1024, K=512)) # [1024, 512] @ [512, 1024]

import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt
import neuronxcc.nki.isa as nisa
from neuronxcc.nki import trace
from neuronxcc.nki.language import par_dim

@trace
def sg0000(
  v1,
  v2,
  v3,
):
  import numpy as np
  import neuronxcc.nki as nki
  import neuronxcc.nki.language as nl
  import neuronxcc.nki.typing as nt
  import neuronxcc.nki.isa as nisa
  from neuronxcc.nki import trace
  from neuronxcc.nki.language import par_dim
 
  v1 = v1
  v2 = v2
  v3 = v3
 
  v4 = nl.shared_constant(np.identity(128, dtype=np.float32))
  v5 = nl.ndarray((nl.par_dim(128), 128), dtype=np.float32, name="identity_local_127", buffer=nl.sbuf)
  v6 = nl.ndarray((2, 4, nl.par_dim(128), 1024), dtype=np.float32, name="y_local_96", buffer=nl.sbuf)
  v7 = nl.ndarray((2, 2, nl.par_dim(128), 1024), dtype=np.float32, name="86.118", buffer=nl.sbuf)
  v8 = nl.zeros((2, 2, 8, nl.par_dim(128), 128), dtype=np.float32, name="86.123", buffer=nl.psum, lazy_initialization=True)
  v9 = nl.ndarray((2, 2, nl.par_dim(128), 8, 128), dtype=np.float32, name="x_pftranspose_86", buffer=nl.sbuf)
  v10 = nl.zeros((2, 2, 2, 2, nl.par_dim(128), 512), dtype=np.float32, name="", buffer=nl.psum, lazy_initialization=True)
  v11 = nl.ndarray((2, 2, 2, nl.par_dim(128), 1024), dtype=np.float32, name="", buffer=nl.sbuf)
 
  def BB_entry_1():
    v5[nl.arange(128)[:, None], nl.arange(128)[None, :]] = nl.load(v4[nl.arange(128)[:, None], nl.arange(128)[None, :]], dtype=np.float32, mask=None)
   
    for i0 in nl.affine_range(2):
      for i1 in nl.affine_range(4):
        v6[i0, i1, nl.arange(128)[:, None], nl.arange(1024)[None, :]] = nl.load(v2[i1, nl.arange(128)[:, None], nl.arange(1024)[None, :]], dtype=np.float32, mask=None)
        """ end loop i1 """
     
      for i2 in nl.affine_range(2):
        v7[i0, i2, nl.arange(128)[:, None, None], 128*nl.arange(8)[None, :, None]+nl.arange(128)[None, None, :]] = nl.load(v1[i0, nl.arange(128)[:, None, None], 8*i2+nl.arange(8)[None, :, None], nl.arange(128)[None, None, :]], dtype=np.float32, mask=None)
       
        for i3 in nl.affine_range(8):
          v8[i0, i2, i3, nl.arange(128)[:, None], nl.arange(128)[None, :]] = nisa.nc_matmul(v7[i0, i2, nl.arange(128)[:, None], 128*i3+nl.arange(128)[None, :]], v5[nl.arange(128)[:, None], nl.arange(128)[None, :]], is_stationary_onezero=False, is_moving_onezero=True, mask=None, is_transpose=True)
          v9[i0, i2, nl.arange(128)[:, None], i3, nl.arange(128)[None, :]] = nl.copy(v8[i0, i2, i3, nl.arange(128)[:, None], nl.arange(128)[None, :]], dtype=np.float32, mask=None)
          """ end loop i3 """
       
        for i4 in nl.affine_range(2):
          for i5 in nl.affine_range(2):
            for i6 in nl.affine_range(4):
              v10[i0, i2, i4, i5, nl.arange(128)[:, None], nl.arange(512)[None, :]] += nisa.nc_matmul(v9[i0, i2, nl.arange(128)[:, None], i6+4*i4, nl.arange(128)[None, :]], v6[i0, i6, nl.arange(128)[:, None], 512*i5+nl.arange(512)[None, :]], is_stationary_onezero=False, is_moving_onezero=False, mask=None)
              """ end loop i6 """
            v11[i0, i2, i4, nl.arange(128)[:, None], 512*i5+nl.arange(512)[None, :]] = nl.copy(v10[i0, i2, i4, i5, nl.arange(128)[:, None], nl.arange(512)[None, :]], dtype=np.float32, mask=None)
            """ end loop i5 """
          nl.store(v3[i0, nl.arange(128)[:, None], i4+2*i2, nl.arange(1024)[None, :]], value=v11[i0, i2, i4, nl.arange(128)[:, None], nl.arange(1024)[None, :]], mask=None)
          """ end loop i4 """
        """ end loop i2 """
      """ end loop i0 """
 
  BB_entry_1()


cu = sg0000.specialize(
  nt.tensor[(2, 128, 16, 128), np.float32], # i=0
  nt.tensor[(4, 128, 1024), np.float32], # i=1
  nt.tensor[(2, 128, 4, 1024), np.float32], # i=2
)
print(cu)
ir = cu


# nki.simulate_kernel(sg0000, 
  # np.ndarray(shape=(2, 128, 16, 128), dtype=np.float32), # i=0
  # np.ndarray(shape=(4, 128, 1024), dtype=np.float32), # i=1
  # np.ndarray(shape=(2, 128, 4, 1024), dtype=np.float32), # i=2
# )

Now we are seeing more tiling structure.

Let’s try something even larger to see how it gets handled!

display_code(gen_matmul_nki(M=4096, N=4096, K=2048)) # [4096, 2048] @ [2048, 4096]

import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt
import neuronxcc.nki.isa as nisa
from neuronxcc.nki import trace
from neuronxcc.nki.language import par_dim

@trace
def sg0000(
  v1,
  v2,
  v3,
):
  import numpy as np
  import neuronxcc.nki as nki
  import neuronxcc.nki.language as nl
  import neuronxcc.nki.typing as nt
  import neuronxcc.nki.isa as nisa
  from neuronxcc.nki import trace
  from neuronxcc.nki.language import par_dim
 
  v1 = v1
  v2 = v2
  v3 = v3
 
  v4 = nl.shared_constant(np.identity(128, dtype=np.float32))
  v5 = nl.ndarray((nl.par_dim(128), 128), dtype=np.float32, name="identity_local_130", buffer=nl.sbuf)
  v6 = nl.ndarray((2, 4, 2, 8, nl.par_dim(128), 1024), dtype=np.float32, name="y_local_97", buffer=nl.sbuf)
  v7 = nl.ndarray((4, 2, 16, 2, nl.par_dim(128), 1024), dtype=np.float32, name="", buffer=nl.sbuf)
  v8 = nl.zeros((2, 4, 16, 2, 8, nl.par_dim(128), 128), dtype=np.float32, name="86.126", buffer=nl.psum, lazy_initialization=True)
  v9 = nl.ndarray((4, 2, 16, 2, nl.par_dim(128), 8, 128), dtype=np.float32, name="", buffer=nl.sbuf)
  v10 = nl.zeros((2, 4, 16, 2, nl.par_dim(128), 512), dtype=np.float32, name="", buffer=nl.psum, lazy_initialization=True)
  v11 = nl.ndarray((2, 16, 4, nl.par_dim(128), 1024), dtype=np.float32, name="", buffer=nl.sbuf)
 
  def BB_entry_1():
    v5[nl.arange(128)[:, None], nl.arange(128)[None, :]] = nl.load(v4[nl.arange(128)[:, None], nl.arange(128)[None, :]], dtype=np.float32, mask=None)
   
    for i0 in nl.affine_range(2):
      for i1 in nl.affine_range(4):
        for i2 in nl.affine_range(2):
          for i3 in nl.affine_range(8):
            v6[i0, i1, i2, i3, nl.arange(128)[:, None], nl.arange(1024)[None, :]] = nl.load(v2[i3+8*i2, nl.arange(128)[:, None], 1024*i1+nl.arange(1024)[None, :]], dtype=np.float32, mask=None)
            """ end loop i3 """
          """ end loop i2 """
       
        for i4 in nl.affine_range(16):
          for i5 in nl.affine_range(2):
            v7[i1, i0, i4, i5, nl.arange(128)[:, None, None], 128*nl.arange(8)[None, :, None]+nl.arange(128)[None, None, :]] = nl.load(v1[i0, i4, nl.arange(128)[:, None, None], 8*i5+nl.arange(8)[None, :, None], nl.arange(128)[None, None, :]], dtype=np.float32, mask=None)
           
            for i6 in nl.affine_range(8):
              v8[i0, i1, i4, i5, i6, nl.arange(128)[:, None], nl.arange(128)[None, :]] = nisa.nc_matmul(v7[i1, i0, i4, i5, nl.arange(128)[:, None], 128*i6+nl.arange(128)[None, :]], v5[nl.arange(128)[:, None], nl.arange(128)[None, :]], is_stationary_onezero=False, is_moving_onezero=True, mask=None, is_transpose=True)
              v9[i1, i0, i4, i5, nl.arange(128)[:, None], i6, nl.arange(128)[None, :]] = nl.copy(v8[i0, i1, i4, i5, i6, nl.arange(128)[:, None], nl.arange(128)[None, :]], dtype=np.float32, mask=None)
              """ end loop i6 """
            """ end loop i5 """
         
          for i7 in nl.affine_range(2):
            for i8 in nl.affine_range(2):
              for i9 in nl.affine_range(8):
                v10[i0, i1, i4, i7, nl.arange(128)[:, None], nl.arange(512)[None, :]] += nisa.nc_matmul(v9[i1, i0, i4, i8, nl.arange(128)[:, None], i9, nl.arange(128)[None, :]], v6[i0, i1, i8, i9, nl.arange(128)[:, None], 512*i7+nl.arange(512)[None, :]], is_stationary_onezero=False, is_moving_onezero=False, mask=None)
                """ end loop i9 """
              """ end loop i8 """
            v11[i0, i4, i1, nl.arange(128)[:, None], 512*i7+nl.arange(512)[None, :]] = nl.copy(v10[i0, i1, i4, i7, nl.arange(128)[:, None], nl.arange(512)[None, :]], dtype=np.float32, mask=None)
            """ end loop i7 """
          nl.store(v3[i0, i4, nl.arange(128)[:, None], 1024*i1+nl.arange(1024)[None, :]], value=v11[i0, i4, i1, nl.arange(128)[:, None], nl.arange(1024)[None, :]], mask=None)
          """ end loop i4 """
        """ end loop i1 """
      """ end loop i0 """
 
  BB_entry_1()


cu = sg0000.specialize(
  nt.tensor[(2, 16, 128, 16, 128), np.float32], # i=0
  nt.tensor[(16, 128, 4096), np.float32], # i=1
  nt.tensor[(2, 16, 128, 4096), np.float32], # i=2
)
print(cu)
ir = cu


# nki.simulate_kernel(sg0000, 
  # np.ndarray(shape=(2, 16, 128, 16, 128), dtype=np.float32), # i=0
  # np.ndarray(shape=(16, 128, 4096), dtype=np.float32), # i=1
  # np.ndarray(shape=(2, 16, 128, 4096), dtype=np.float32), # i=2
# )

The generated NKI code for matrix multiplication with the large size is very similar to the Optimized Matmul Code in the tutorial.

Have fun experimenting with NKIPy now!