A primitive tensor function can be seen as a self-contained unit that takes input data, performs a computation, and produces the expected result. Linear, Add, and ReLU are primitive tensor functions, and fused functions like linear_add or add_relu are primitive tensor functions as well.
A primitive tensor function does not prescribe its implementation. Taking add as an example, we can call torch.add, write a vanilla add in plain Python, or even write a parallelized add with the help of OpenMP.
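To make this concrete, here is a minimal plain-Python sketch, assuming lists of floats stand in for tensors: the same primitive function "add" implemented once as a vanilla loop and once split into chunks handled by a thread pool. The names `add_vanilla` and `add_parallel` are hypothetical, and the thread pool only illustrates the idea of a parallel implementation; it is not OpenMP, and in CPython it will not speed up pure-Python arithmetic.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def add_vanilla(a: List[float], b: List[float]) -> List[float]:
    """Element-wise add with a plain Python loop (hypothetical helper)."""
    assert len(a) == len(b)
    out = [0.0] * len(a)
    for i in range(len(a)):
        out[i] = a[i] + b[i]
    return out


def add_parallel(a: List[float], b: List[float], workers: int = 4) -> List[float]:
    """Element-wise add where each worker fills one contiguous chunk of the output."""
    assert len(a) == len(b)
    n = len(a)
    out = [0.0] * n
    chunk = max(1, (n + workers - 1) // workers)  # ceiling division, at least 1

    def work(start: int) -> None:
        for i in range(start, min(start + chunk, n)):
            out[i] = a[i] + b[i]

    with ThreadPoolExecutor(max_workers=workers) as pool:
        # list() waits for every task and re-raises any exception from a worker
        list(pool.map(work, range(0, n, chunk)))
    return out


a = [float(i) for i in range(128)]
b = [1.0] * 128
# Same interface, same result, different implementation strategy
assert add_vanilla(a, b) == add_parallel(a, b)
```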
Example of TPA (Tensor Program Abstraction)
A typical tensor program abstraction contains several parts: buffers, loop nests, blocks, and computation statements.
```python
from tvm.script import tir as T


@T.prim_func
def main(A: T.Buffer[128, "float32"],   # (Multi-dimensional) buffers that
         B: T.Buffer[128, "float32"],   # hold the input, output, and
         C: T.Buffer[128, "float32"]):  # intermediate results.
    for i in range(128):  # Loop nests that drive compute iterations.
        with T.block("C"):  # Blocks can be retrieved and iterated in IRModules.
            vi = T.axis.spatial(128, i)  # Extra information about iteration (spatial or reduction).
            C[vi] = A[vi] + B[vi]  # Computation statements.
```
The same function can be wrapped in an IRModule:

```python
import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"],   # (Multi-dimensional) buffers that
             B: T.Buffer[128, "float32"],   # hold the input, output, and
             C: T.Buffer[128, "float32"]):  # intermediate results.
        for i in range(128):  # Loop nests that drive compute iterations.
            with T.block("C"):  # Blocks can be retrieved and iterated in IRModules.
                vi = T.axis.spatial(128, i)  # Extra information about iteration (spatial or reduction).
                C[vi] = A[vi] + B[vi]  # Computation statements.
```
An IRModule has a script() method that prints the TVMScript of the module:

```python
print(MyModule.script())
```

Output:

```python
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i in tir.serial(128):
            with tir.block("C"):
                vi = tir.axis.spatial(128, i)
                tir.reads(A[vi], B[vi])
                tir.writes(C[vi])
                C[vi] = A[vi] + B[vi]
```
An IRModule on its own only describes the program. To transform the original code, create a tvm.tir.Schedule from the module and apply transformations through it.
For example, we can retrieve handles to parts of the program, such as blocks and loops, with the schedule's get_* methods (get_block, get_loops):
```python
sch = tvm.tir.Schedule(MyModule)  # create a schedule from the IRModule

# Get the annotated block; the second parameter is the primitive function name, default "main"
block_c = sch.get_block("C")
# sch.get_loops returns a list of loops; "i, =" unpacks its single element.
# Getting the loops of a block means getting the loops that enclose it.
i, = sch.get_loops(block_c)
print(i)
# Output: tir.LoopRV(0x1d353c0)
```
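The trailing comma in `i, = sch.get_loops(block_c)` is ordinary Python single-element unpacking, not TVM syntax. A quick plain-Python illustration, with a list of placeholder strings standing in for the loop list that `get_loops` returns:

```python
# A one-element list standing in for the result of sch.get_loops(block_c)
loops = ["loop_i"]

# "i, = ..." unpacks exactly one element; here it is equivalent to i = loops[0]
i, = loops
print(i)  # loop_i

# Without the comma, the name is bound to the whole list instead
whole = loops
print(whole)  # ['loop_i']

# If the block had more than one enclosing loop, single-element unpacking fails loudly
try:
    i, = ["loop_i", "loop_j"]
except ValueError:
    print("more than one loop: unpacking raised ValueError")
```

Because the failing assignment raises before binding anything, `i` still holds `"loop_i"` afterwards.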
Once we have the loops, we can transform them in many ways, e.g. split, reorder, unroll, vectorize, or parallelize them:
```python
# None is a placeholder whose value is inferred (128 / (4 * 4) = 8)
i0, i1, i2 = sch.split(i, factors=[None, 4, 4])
print(sch.mod.script())
```

Output:

```python
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i_0, i_1, i_2 in tir.grid(8, 4, 4):
            with tir.block("C"):
                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                tir.reads(A[vi], B[vi])
                tir.writes(C[vi])
                C[vi] = A[vi] + B[vi]
```
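The new binding `vi = i_0 * 16 + i_1 * 4 + i_2` is what keeps the split loop equivalent to the original one. A small plain-Python check (no TVM needed) that the three split loops with extents (8, 4, 4) visit every index in 0..127 exactly once, and in the original order:

```python
from itertools import product

# Extents produced by sch.split(i, factors=[None, 4, 4]) on a loop of 128:
# None is inferred as 128 // (4 * 4) = 8
extents = (8, 4, 4)

visited = []
# product() varies the last index fastest, so i_2 is the innermost loop
for i_0, i_1, i_2 in product(*(range(e) for e in extents)):
    vi = i_0 * 16 + i_1 * 4 + i_2  # same binding as in the printed TensorIR
    visited.append(vi)

# Outer-to-inner order (i_0, i_1, i_2) reproduces 0, 1, ..., 127
assert visited == list(range(128))
```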
As we can see, the schedule holds its own copy of the IRModule in sch.mod. Every transformation updates sch.mod, which is reflected in sch.mod.script().
We can apply more transformations:
```python
sch.reorder(i2, i1, i0)
sch.parallel(i0)
sch.vectorize(i1)
print(sch.mod.script())
```

Output:

```python
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i_2 in tir.serial(4):
            for i_1 in tir.vectorized(4):
                for i_0 in tir.parallel(8):
                    with tir.block("C"):
                        vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                        tir.reads(A[vi], B[vi])
                        tir.writes(C[vi])
                        C[vi] = A[vi] + B[vi]
```
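Reordering changes the order in which `vi` values are visited, but because `vi` is a spatial axis, each iteration writes a different `C[vi]`, so the final result does not depend on that order. A plain-Python sketch of the loop nest before and after `sch.reorder(i2, i1, i0)`; it ignores the parallel and vectorize annotations, which change how iterations run, not which ones run:

```python
A = [float(v) for v in range(128)]
B = [1.0] * 128

# Original order: i_0, i_1, i_2 (outer to inner)
C_ref = [0.0] * 128
order_ref = []
for i_0 in range(8):
    for i_1 in range(4):
        for i_2 in range(4):
            vi = i_0 * 16 + i_1 * 4 + i_2
            order_ref.append(vi)
            C_ref[vi] = A[vi] + B[vi]

# Reordered: i_2, i_1, i_0 (outer to inner), same binding for vi
C_new = [0.0] * 128
order_new = []
for i_2 in range(4):
    for i_1 in range(4):
        for i_0 in range(8):
            vi = i_0 * 16 + i_1 * 4 + i_2
            order_new.append(vi)
            C_new[vi] = A[vi] + B[vi]

print(order_new[:4])  # [0, 16, 32, 48]
assert order_new != order_ref          # the iteration order differs
assert sorted(order_new) == order_ref  # but the same iterations run
assert C_new == C_ref                  # and the result is identical
```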
From the examples above, we can draw a few observations:
- The computation is enclosed in a block, like the block "C" in the example.
- Every optimization of the original script has to be applied through a concrete schedule. (I have not found a way to undo a transformation so far.)
build and run
After transforming the program, we can build it into a runnable function with tvm.build:

```python
import numpy as np
import tvm

# Build with a supported target as the backend
rt_mod = tvm.build(sch.mod, target="llvm")
# tvm.build compiles all prim functions inside the IRModule;
# retrieve one by its **global symbol** name
func = rt_mod["main"]
type(func)  # tvm.runtime.packed_func.PackedFunc

# func is now callable; to call a TVM PackedFunc, the data must be TVM NDArrays
a = tvm.nd.array(np.arange(128, dtype="float32"))
b = tvm.nd.array(np.ones(128, dtype="float32"))
c = tvm.nd.empty([128], dtype="float32")  # allocate an empty array to store the result

# Call func as usual; the result is written into c
func(a, b, c)
```
To evaluate the performance of an optimization, TVM provides an evaluator:

```python
func_timer = rt_mod.time_evaluator("main", tvm.cpu())
# .mean is an attribute of the returned benchmark result, not a method
print("Time cost of MyModule is %g sec" % func_timer(a, b, c).mean)
```
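Conceptually, `time_evaluator` runs the compiled function repeatedly and reports statistics such as the mean. The plain-Python sketch below mimics only that idea with the standard `timeit` module on a hypothetical pure-Python `add`; it is an analogy, not how TVM measures time internally, and the printed number will vary by machine.

```python
import timeit


def add(a, b, c):
    """Plain-Python stand-in for the compiled `main` function: c = a + b."""
    for i in range(len(a)):
        c[i] = a[i] + b[i]


a = [float(i) for i in range(128)]
b = [1.0] * 128
c = [0.0] * 128

number = 100  # calls per measurement
repeat = 5    # number of measurements
# Each entry of `times` is the total time of `number` calls
times = timeit.repeat(lambda: add(a, b, c), number=number, repeat=repeat)
mean_per_call = sum(times) / (repeat * number)
print("Time cost of add is %g sec" % mean_per_call)
```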
Case Study on MM_Relu
Low-level numpy implementation:
```python
import numpy as np


def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)
```
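The `if k == 0` line plays the role of `T.init()` in the TensorIR version: it resets the accumulator the first time each `Y[i, j]` is touched. Here is a plain-Python port (lists instead of NumPy arrays, and a hypothetical size parameter `n` so it runs quickly), checked against a direct ReLU(A @ B) formula. Integer-valued inputs keep the floating-point comparison exact.

```python
from typing import List

Matrix = List[List[float]]


def lnumpy_mm_relu_py(A: Matrix, B: Matrix, C: Matrix, n: int) -> None:
    """List-based port of lnumpy_mm_relu for an n x n problem."""
    Y = [[0.0] * n for _ in range(n)]  # contents are reset at k == 0 anyway
    for i in range(n):
        for j in range(n):
            for k in range(n):
                if k == 0:
                    Y[i][j] = 0.0  # reduction init, like T.init()
                Y[i][j] = Y[i][j] + A[i][k] * B[k][j]
    for i in range(n):
        for j in range(n):
            C[i][j] = max(Y[i][j], 0.0)


def reference(A: Matrix, B: Matrix, n: int) -> Matrix:
    """ReLU(A @ B) written as one expression per element."""
    return [[max(sum(A[i][k] * B[k][j] for k in range(n)), 0.0) for j in range(n)]
            for i in range(n)]


n = 8
A = [[float((i * n + k) % 5 - 2) for k in range(n)] for i in range(n)]
B = [[float((k + 2 * j) % 3 - 1) for j in range(n)] for k in range(n)]
C = [[0.0] * n for _ in range(n)]
lnumpy_mm_relu_py(A, B, C, n)
assert C == reference(A, B, n)
```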
```python
# `@tvm.script.ir_module` indicates that MyModule is an IRModule.
# IRModule is the container object that holds a collection of tensor functions in machine learning compilation.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):
        # `global_symbol` corresponds to the name of the function.
        # `tir.noalias` is an attribute indicating that all the buffer memories do not overlap.
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
```
The differences between the low-level numpy version and the IRModule are:
- Variable type: TVM treats all variables as buffers, whether they are parameters or temporary variables.
- For loop: T.grid is syntactic sugar for nested loops.
- Block: a block is the basic unit of computation in TensorIR. The computations are wrapped in named blocks like "Y" and "C", which can be retrieved with sch.get_block.
- Extra information: the more information we give the compiler, the more it can do for us. For example, T.axis.spatial marks an axis whose iterations are independent of each other (their order does not matter), while T.axis.reduce marks an axis over which a reduction is performed.
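`T.grid(128, 128, 128)` corresponds to three nested `range` loops. In plain Python, `itertools.product` offers the same kind of sugar; the check below (with shrunk extents to keep it fast) confirms both produce the same index sequence. This is an analogy for the loop structure only, not TVM's implementation of `T.grid`.

```python
from itertools import product

ni, nj, nk = 4, 3, 2  # small extents instead of 128 x 128 x 128

# Explicit nested loops
nested = []
for i in range(ni):
    for j in range(nj):
        for k in range(nk):
            nested.append((i, j, k))

# One flat loop header, like `for i, j, k in T.grid(ni, nj, nk)`
grid = [(i, j, k) for i, j, k in product(range(ni), range(nj), range(nk))]

assert grid == nested
print(len(grid))  # 24
```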
Block Axis properties
```python
for i, j, k in T.grid(128, 128, 128):
    with T.block("Y"):
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        vk = T.axis.reduce(128, k)
        with T.init():
            Y[vi, vj] = T.float32(0)
        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
```
Notably, for fixed values of vi and vj, the computation block produces a point value at a spatial location of Y (Y[vi, vj]) that is independent of other locations in Y (those with different vi, vj values). We call vi and vj spatial axes because they directly correspond to the beginning of a spatial region of the buffers that the block writes to. The axes involved in the reduction (vk) are called reduce axes.
The extra axis information also helps validate correctness. For example, if an axis declares an extent of 128 but is bound to a loop over range(127), an exception is raised. Besides, the extra information lets the compiler apply more transformations based on the dependencies between iterations.
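A small plain-Python experiment makes the spatial/reduce distinction concrete: visiting the spatial pairs `(vi, vj)` in a shuffled order gives the same `Y`, because each pair owns its own output element. The reduce axis `vk`, in contrast, accumulates into one element and needs the initialization that `T.init()` provides. Sizes are shrunk and integer values are used so floating-point results compare exactly; the function `matmul` is a hypothetical helper.

```python
import random
from itertools import product

n = 6
A = [[float((i + 2 * k) % 4 - 1) for k in range(n)] for i in range(n)]
B = [[float((3 * k + j) % 5 - 2) for j in range(n)] for k in range(n)]


def matmul(spatial_order):
    """Compute Y = A @ B, visiting the spatial (vi, vj) pairs in the given order."""
    Y = [[None] * n for _ in range(n)]
    for vi, vj in spatial_order:
        for vk in range(n):      # reduce axis: accumulates into Y[vi][vj]
            if vk == 0:
                Y[vi][vj] = 0.0  # init, like T.init()
            Y[vi][vj] += A[vi][vk] * B[vk][vj]
    return Y


in_order = list(product(range(n), range(n)))
shuffled = in_order[:]
random.Random(0).shuffle(shuffled)  # fixed seed for a reproducible order

# Spatial axes: any visiting order produces the same result
assert matmul(in_order) == matmul(shuffled)
```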
Sugars
The declaration of the axes can be simplified as:

```python
# "SSR" means the properties of the axes are "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
```
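`T.axis.remap` is shorthand for declaring one block axis per loop variable, where each character of the string selects the axis kind ("S" for spatial, "R" for reduce). To illustrate only that mapping, here is a hypothetical plain-Python helper `remap` (not part of TVM) that expands the string into the long-form declarations it abbreviates. One simplification: the real `T.axis.remap` takes the extents from the bound loops, while this helper receives them explicitly.

```python
from typing import List, Tuple

# Hypothetical table: kind character -> long-form constructor name
KINDS = {"S": "T.axis.spatial", "R": "T.axis.reduce"}


def remap(kinds: str, loop_vars: List[str], extents: List[int]) -> List[Tuple[str, int, str]]:
    """Return (constructor, extent, loop_var) for each block axis (illustration only)."""
    if len(kinds) != len(loop_vars) or len(loop_vars) != len(extents):
        raise ValueError("need one kind character and one extent per loop variable")
    return [(KINDS[c], e, v) for c, e, v in zip(kinds, extents, loop_vars)]


for ctor, extent, var in remap("SSR", ["i", "j", "k"], [128, 128, 128]):
    print(f"{ctor}({extent}, {var})")
# T.axis.spatial(128, i)
# T.axis.spatial(128, j)
# T.axis.reduce(128, k)
```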
- global_symbol: the name of the function, unique within the IRModule; it is used to retrieve the function from the built module.
- tir.noalias: indicates that none of the buffer memories overlap.
Transformation
With the remap sugar, the module looks like this; we use it as the starting point for the transformations:

```python
@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
```
Suppose we want to transform the loops to make better use of the cache for matrix B, as in this low-level numpy version:

```python
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j0 in range(32):
            for k in range(128):
                for j1 in range(4):
                    # Compute 4 consecutive elements of matrix B/Y at a time
                    j = j0 * 4 + j1
                    if k == 0:
                        Y[i, j] = 0
                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)
```
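Splitting `j` into `j0 * 4 + j1` and moving `k` between the two parts only changes the visiting order: every `Y[i, j]` still sees `k = 0` first, so the `if k == 0` initialization remains valid. A plain-Python check of that claim, with list-based ports of both versions (hypothetical names `mm_relu_v1` and `mm_relu_v2`) and a smaller size `n = 16`, which must be divisible by the split factor 4:

```python
from typing import List

Matrix = List[List[float]]


def mm_relu_v1(A: Matrix, B: Matrix, n: int) -> Matrix:
    """Port of lnumpy_mm_relu: loops i, j, k."""
    Y = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                if k == 0:
                    Y[i][j] = 0.0
                Y[i][j] += A[i][k] * B[k][j]
    return [[max(v, 0.0) for v in row] for row in Y]


def mm_relu_v2(A: Matrix, B: Matrix, n: int, factor: int = 4) -> Matrix:
    """Port of lnumpy_mm_relu_v2: j split into (j0, j1), loops i, j0, k, j1."""
    assert n % factor == 0
    Y = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j0 in range(n // factor):
            for k in range(n):
                for j1 in range(factor):  # `factor` consecutive columns of B / Y
                    j = j0 * factor + j1
                    if k == 0:
                        Y[i][j] = 0.0
                    Y[i][j] += A[i][k] * B[k][j]
    return [[max(v, 0.0) for v in row] for row in Y]


n = 16
A = [[float((i * 7 + k) % 9 - 4) for k in range(n)] for i in range(n)]
B = [[float((k * 5 + j) % 7 - 3) for j in range(n)] for k in range(n)]
assert mm_relu_v1(A, B, n) == mm_relu_v2(A, B, n)
```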
In this case, we need to:
- split the j loop into two parts (j0 and j1);
- reorder the loops so that k sits between j0 and j1.
```python
import tvm
from IPython.display import Code

# Initialize the schedule
sch = tvm.tir.Schedule(MyModuleWithAxisRemapSugar)
# Get the block
block_Y = sch.get_block("Y", func_name="mm_relu")
# Get the loops enclosing block Y
i, j, k = sch.get_loops(block_Y)
# Split j into two parts
j0, j1 = sch.split(j, factors=[None, 4])
# Reorder so that k sits between j0 and j1
sch.reorder(j0, k, j1)
Code(sch.mod.script(), language="python")
```

Output:

```python
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i, j_0, k, j_1 in tir.grid(128, 32, 128, 4):
            with tir.block("Y"):
                vi = tir.axis.spatial(128, i)
                vj = tir.axis.spatial(128, j_0 * 4 + j_1)
                vk = tir.axis.reduce(128, k)
                tir.reads(A[vi, vk], B[vk, vj])
                tir.writes(Y[vi, vj])
                with tir.init():
                    Y[vi, vj] = tir.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in tir.grid(128, 128):
            with tir.block("C"):
                vi, vj = tir.axis.remap("SS", [i, j])
                tir.reads(Y[vi, vj])
                tir.writes(C[vi, vj])
                C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))
```
Moreover, blocks C and Y iterate over the same i and j ranges, so we can combine their loop nests.
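One way to combine them, sketched here in plain Python, is to apply the ReLU to each group of 4 columns as soon as its reduction over `k` is finished, inside the `j0` loop, so the separate loop nest for `C` disappears. In TVM this kind of move is what `sch.reverse_compute_at` is for (for example, moving block `C` under loop `j0`); the code below only mimics the resulting loop structure, with hypothetical function names and a smaller size `n = 16`.

```python
from typing import List

Matrix = List[List[float]]


def mm_relu_v2(A: Matrix, B: Matrix, n: int, factor: int = 4) -> Matrix:
    """Split/reordered version with a separate ReLU loop nest (block C)."""
    Y = [[0.0] * n for _ in range(n)]
    C = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j0 in range(n // factor):
            for k in range(n):
                for j1 in range(factor):
                    j = j0 * factor + j1
                    if k == 0:
                        Y[i][j] = 0.0
                    Y[i][j] += A[i][k] * B[k][j]
    for i in range(n):
        for j in range(n):
            C[i][j] = max(Y[i][j], 0.0)
    return C


def mm_relu_fused(A: Matrix, B: Matrix, n: int, factor: int = 4) -> Matrix:
    """Block C moved under j0: the ReLU runs right after each column group is reduced."""
    Y = [[0.0] * n for _ in range(n)]
    C = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j0 in range(n // factor):
            for k in range(n):
                for j1 in range(factor):
                    j = j0 * factor + j1
                    if k == 0:
                        Y[i][j] = 0.0
                    Y[i][j] += A[i][k] * B[k][j]
            for ax in range(factor):  # former block C, now sharing loops i and j0
                j = j0 * factor + ax
                C[i][j] = max(Y[i][j], 0.0)
    return C


n = 16
A = [[float((i * 3 + k) % 8 - 4) for k in range(n)] for i in range(n)]
B = [[float((k * 2 + j) % 6 - 2) for j in range(n)] for k in range(n)]
assert mm_relu_fused(A, B, n) == mm_relu_v2(A, B, n)
```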
Here te.compute has the signature te.compute(output_shape, fcompute). The fcompute function (the lambda function in the example) describes how to compute the value of each element Y[i, j] from its index.
In this particular case, we want to create a function with two input parameters (A, B) and one output parameter (C).
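To show what the `(output_shape, fcompute)` pair means without relying on TVM, here is a hypothetical plain-Python `compute` helper. It is not the `te.compute` API, which builds a symbolic tensor rather than evaluating eagerly; it simply fills an output by calling `fcompute` at every index. The lambdas describe how one element `Y[i, j]` (and then `C[i, j]`) is computed, mirroring the role of the lambda passed to `te.compute`.

```python
from itertools import product
from typing import Callable, Dict, Tuple

Index = Tuple[int, ...]


def compute(shape: Tuple[int, ...], fcompute: Callable[..., float]) -> Dict[Index, float]:
    """Hypothetical eager analog of te.compute: out[idx] = fcompute(*idx) for every idx."""
    return {idx: fcompute(*idx) for idx in product(*(range(s) for s in shape))}


n = 4
A = [[float(i + k) for k in range(n)] for i in range(n)]
B = [[float(k - j) for j in range(n)] for k in range(n)]

# fcompute describes how ONE element Y[i, j] is computed from its index
Y = compute((n, n), lambda i, j: sum(A[i][k] * B[k][j] for k in range(n)))
C = compute((n, n), lambda i, j: max(Y[(i, j)], 0.0))

print(Y[(0, 0)])  # 14.0
```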