# Tensor Program Abstraction

Credit to: [mlc.ai by Tianqi Chen](https://mlc.ai/chapter_tensor_program/index.html)

Before anything else, here is a way to pretty-print a TVMScript module:

```py
import IPython

IPython.display.Code(MyModule.script(), language="python")
```

## Primitive Tensor Function

![primitive tensor function](https://mlc.ai/_images/primitive_tensor_func.png)

A **primitive tensor function** can be seen as a self-contained unit responsible for taking the input data, performing the computation, and producing the expected result. `Linear`, `Add`, and `Relu` are primitive tensor functions; fused functions like `linear_add` or `add_relu` can also be seen as primitive tensor functions.

A **primitive tensor function** does not restrict its implementation. Take `add` as an example: we can call `torch.add`, write a vanilla add in plain Python, or even write a parallelized add with the aid of OpenMP.
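
For instance, here is a sketch of two interchangeable implementations of the same `add` primitive (function names are illustrative; an OpenMP-parallelized C version would be a third option):

```py
import numpy as np
import torch

def add_torch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # delegate to the library-provided kernel
    return torch.add(a, b)

def add_vanilla(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # explicit elementwise loop implementing the same primitive tensor function
    c = np.empty_like(a)
    for i in range(a.shape[0]):
        c[i] = a[i] + b[i]
    return c
```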

## Example of TPA

The typical Tensor Program Abstraction contains several parts: buffers, loop nests, and computation statements.

```py
from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer[128, "float32"],     # (Multi-dimensional) buffers that
         B: T.Buffer[128, "float32"],     # hold the input, output, and
         C: T.Buffer[128, "float32"]):    # intermediate results.

    for i in range(128):                  # Loop nests that drive compute iterations.
        with T.block("C"):                # Blocks can be retrieved and iterated in IRModules.
            vi = T.axis.spatial(128, i)   # Extra information about iteration. (Spatial or Reduction)
            C[vi] = A[vi] + B[vi]         # Computation statements.
```

## Essential classes

### [tvm.tir.Schedule](https://tvm.apache.org/docs/reference/api/python/tir.html?highlight=mod#tvm.tir.Schedule)

> A schedule is a set of transformations that change the order of computation but preserve the semantics of computation.

### [tvm.tir.transform](https://tvm.apache.org/docs/reference/api/python/tir.html?highlight=mod#module-tvm.tir.transform)

## TensorIR

### Do2Learn

```py
import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"],     # (Multi-dimensional) buffers that
             B: T.Buffer[128, "float32"],     # hold the input, output, and
             C: T.Buffer[128, "float32"]):    # intermediate results.

        for i in range(128):                  # Loop nests that drive compute iterations.
            with T.block("C"):                # Blocks can be retrieved and iterated in IRModules.
                vi = T.axis.spatial(128, i)   # Extra information about iteration. (Spatial or Reduction)
                C[vi] = A[vi] + B[vi]         # Computation statements.
```

An IRModule provides `func::script` to print the decorated functions as TVMScript.

```py
print(MyModule.script())
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i in tir.serial(128):
            with tir.block("C"):
                vi = tir.axis.spatial(128, i)
                tir.reads(A[vi], B[vi])
                tir.writes(C[vi])
                C[vi] = A[vi] + B[vi]
```

Staying in an IRModule is not enough: if you want to transform the original code, create a `class::Schedule` and apply transformations through it.

```py
sch = tvm.tir.Schedule(MyModule)
type(sch)
-------->
tvm.tir.schedule.schedule.Schedule
```

With a Schedule, we can apply optimizations to the code and then build and run it.

```py
print(list(filter(lambda x: not x.startswith("_"), dir(sch))))
-------->
['add_unit_loop', 'annotate', 'bind', 'blockize', 'cache_read', 'cache_write', 'compute_at', 'compute_inline', 'copy', 'decompose_reduction', 'enter_postproc', 'fork_seed', 'fuse', 'get', 'get_block', 'get_child_blocks', 'get_consumers', 'get_loops', 'get_producers', 'get_sref', 'handle', 'mod', 'parallel', 'reindex', 'remove_rv', 'reorder', 'reverse_compute_at', 'reverse_compute_inline', 'rfactor', 'same_as', 'sample_categorical', 'sample_compute_location', 'sample_perfect_tile', 'seed', 'set_axis_separator', 'set_scope', 'show', 'split', 'state', 'storage_align', 'tensorize', 'trace', 'transform_block_layout', 'transform_layout', 'unannotate', 'unroll', 'vectorize']
```

For example, we can query the code through `sch` using the `func::get_xxx` methods:

```py
block_c = sch.get_block("C")        # get the annotated block; the second parameter is the prim func name, default "main"
# sch.get_loops returns a list of loops; `i,` unpacks the single-element list
i, = sch.get_loops(block_c)         # getting the loops of a block means getting the loops surrounding the block
print(i)
-------->
tir.LoopRV(0x1d353c0)
```

Once we have the loops, we can do many fancy things, e.g. split, unroll, vectorize, parallelize...

```py
i0, i1, i2 = sch.split(i, factors=[None, 4, 4])   # None is a placeholder; TVM infers the remaining factor (128 / (4 * 4) = 8)
print(sch.mod.script())
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i_0, i_1, i_2 in tir.grid(8, 4, 4):
            with tir.block("C"):
                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                tir.reads(A[vi], B[vi])
                tir.writes(C[vi])
                C[vi] = A[vi] + B[vi]
```

As we can see, the schedule holds a copy of the IRModule in its `mod` attribute. Every transformation changes what `sch.mod.script()` prints.
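
A quick check of this (a minimal sketch, assuming the split above has already been applied):

```py
print(type(sch.mod))                            # tvm.ir.module.IRModule
print(MyModule.script() == sch.mod.script())    # False: the original module stays untouched
```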

There can be more transformations:

```py
sch.reorder(i2, i1, i0)
sch.parallel(i0)
sch.vectorize(i1)
print(sch.mod.script())
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # body
        # with tir.block("root")
        for i_2 in tir.serial(4):
            for i_1 in tir.vectorized(4):
                for i_0 in tir.parallel(8):
                    with tir.block("C"):
                        vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                        tir.reads(A[vi], B[vi])
                        tir.writes(C[vi])
                        C[vi] = A[vi] + B[vi]
```

From the example above, we can draw a couple of observations:

1. The computation part is wrapped inside a block, like `C` in the example.
2. Every optimization of the original script needs a concrete schedule to carry it out. (I have not found a reverse operation so far.)

#### Build and run

After making these changes, we can build a runnable function with tvm:

```py
import numpy as np

rt_mod = tvm.build(sch.mod, target="llvm")         # build with a supported target as the backend
func = rt_mod["main"]                              # build all prim functions inside an IRModule and retrieve by mapping **global symbol** name
type(func)                                         # tvm.runtime.packed_func.PackedFunc
# Now the func is a callable function
a = tvm.nd.array(np.arange(128, dtype="float32")) # a tvm PackedFunc expects tvm NDArrays as its arguments
b = tvm.nd.array(np.ones(128, dtype="float32"))
c = tvm.nd.empty([128], dtype="float32")           # allocate an empty array to store the result
# Call func as usual
func(a, b, c)
```
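
To sanity-check the result we can compare against numpy (a minimal sketch; `c.numpy()` converts the tvm NDArray back to a numpy array):

```py
np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy(), rtol=1e-5)
```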

To evaluate the performance of the optimization, tvm provides an evaluator:

```py
func_timer = rt_mod.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule is %g sec" % func_timer(a, b, c).mean())
```

## Case Study on MM\_Relu

Low-level numpy implementation:

```py
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)
```
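
For reference, the same computation in high-level numpy is just a matmul followed by a ReLU; we can use it to check the low-level version (a sketch with illustrative array names):

```py
import numpy as np

a_np = np.random.rand(128, 128).astype("float32")
b_np = np.random.rand(128, 128).astype("float32")
c_np = np.empty((128, 128), dtype="float32")

lnumpy_mm_relu(a_np, b_np, c_np)
np.testing.assert_allclose(c_np, np.maximum(a_np @ b_np, 0), rtol=1e-5)
```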

```py
# `@tvm.script.ir_module` indicate that MyModule is an IRModule. 
# IRModule is the container object to hold a collection of tensor functions in machine learning compilation. 
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):
        # `global_symbol` corresponds to the name of the function.
        # `tir.noalias` is an attribute indicating that all the buffer memories do not overlap.
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
```

The differences between the low-level numpy version and the IRModule are:

1. variable types: tvm treats every tensor as a buffer, whether it is a parameter or an intermediate result (`Y` is allocated with `T.alloc_buffer`).
2. for loops: `T.grid` is syntactic sugar for nested loops.
3. blocks: a block is the basic unit of computation in TensorIR. The computation parts are wrapped in annotated blocks like "Y" and "C", which can later be retrieved with `sch.get_block`.
4. extra information: the more information we provide to the compiler, the more it can do for us. For example, `T.axis.spatial` marks an axis whose iterations are order-independent, while `T.axis.reduce` marks an axis that participates in a reduction.

### Block Axis properties

```py
for i, j, k in T.grid(128, 128, 128):
    with T.block("Y"):
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        vk = T.axis.reduce(128, k)
        with T.init():
            Y[vi, vj] = T.float32(0)
        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
```

Notably, for a fixed value of vi and vj, the computation block produces a point value at a spatial location of Y (Y\[vi, vj]) **that is independent from other locations in Y (with different vi, vj values)**. We call vi and vj **spatial axes** as they directly correspond to the beginning of a spatial region of the buffers that the block writes to. The axis involved in the reduction (vk) is called a **reduce axis**.

The extra axis information helps us validate correctness. For example, if an axis expects an iterator over 128 elements but is bound to a for loop of range(127), an exception is raised. Besides, the extra information lets the compiler perform more advanced transformations based on the dependencies and independence of the axes.

#### Sugars

The initialization of the axes can be simplified as:

```py
# SSR means the property of each axis is "spatial", "spatial", "reduce" respectively
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
```

### Function Annotations

```py
T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
```

Here are two attributes:

* `global_symbol`: the name of the function, unique within the IRModule; it is used to retrieve the function from the built module, as shown in the sketch below.
* `tir.noalias`: indicates that the buffer memories do not overlap.
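
A sketch of that retrieval, assuming the `MyModule` with `mm_relu` defined above:

```py
rt_lib = tvm.build(MyModule, target="llvm")
func_mm_relu = rt_lib["mm_relu"]    # looked up by global_symbol, not by the Python attribute name
```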

## Transformation

```py
@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
    @T.prim_func
    def mm_relu(A: T.Buffer[(128, 128), "float32"],
                B: T.Buffer[(128, 128), "float32"],
                C: T.Buffer[(128, 128), "float32"]):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
```

Suppose we want to transform the program to make better use of the cache when reading matrix B. In low-level numpy, the transformed version looks like this:

```py
def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j0 in range(32):
            for k in range(128):
                for j1 in range(4):                             # here we compute 4 contiguous elements of B/Y at a time
                    j = j0 * 4 + j1
                    if k == 0:
                        Y[i, j] = 0
                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)
```

Why does this make better use of the cache? See this figure: ![tensor func loop order](https://mlc.ai/_images/tensor_func_loop_order.png)

In this case, we need to:

1. split the j loop into two parts
2. reorder the j loop with k loop

```py
# Initialize the schedule
sch = tvm.tir.Schedule(MyModuleWithAxisRemapSugar)

# get the block
block_Y = sch.get_block("Y", func_name="mm_relu")
# get the outside loops of Y
i, j, k = sch.get_loops(block_Y)
# split j into two parts
j0, j1 = sch.split(j, factors=[None, 4])
# reorder
sch.reorder(j0, k, j1)

IPython.display.Code(sch.mod.script(), language="python")
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i, j_0, k, j_1 in tir.grid(128, 32, 128, 4):
            with tir.block("Y"):
                vi = tir.axis.spatial(128, i)
                vj = tir.axis.spatial(128, j_0 * 4 + j_1)
                vk = tir.axis.reduce(128, k)
                tir.reads(A[vi, vk], B[vk, vj])
                tir.writes(Y[vi, vj])
                with tir.init():
                    Y[vi, vj] = tir.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in tir.grid(128, 128):
            with tir.block("C"):
                vi, vj = tir.axis.remap("SS", [i, j])
                tir.reads(Y[vi, vj])
                tir.writes(C[vi, vj])
                C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))
```

What's more, block C and block Y share part of their loops, so we can move block C under the same loop nest.

```py
block_C = sch.get_block("C", func_name="mm_relu")
sch.reverse_compute_at(block_C, j0)

IPython.display.Code(sch.mod.script(), language="python")
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i, j_0 in tir.grid(128, 32):
            for k, j_1 in tir.grid(128, 4):
                with tir.block("Y"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + j_1)
                    vk = tir.axis.reduce(128, k)
                    tir.reads(A[vi, vk], B[vk, vj])
                    tir.writes(Y[vi, vj])
                    with tir.init():
                        Y[vi, vj] = tir.float32(0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in tir.serial(4):
                with tir.block("C"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + ax0)
                    tir.reads(Y[vi, vj])
                    tir.writes(C[vi, vj])
                    C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))
```

Finally, we can separate the initialization part of Y from its reduction update using `func::decompose_reduction`:

```py
sch.decompose_reduction(block_Y, k)
IPython.display.Code(sch.mod.script(), language="python")
-------->

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i, j_0 in tir.grid(128, 32):
            for j_1_init in tir.serial(4):
                with tir.block("Y_init"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + j_1_init)
                    tir.reads()
                    tir.writes(Y[vi, vj])
                    Y[vi, vj] = tir.float32(0)
            for k, j_1 in tir.grid(128, 4):
                with tir.block("Y_update"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + j_1)
                    vk = tir.axis.reduce(128, k)
                    tir.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
                    tir.writes(Y[vi, vj])
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in tir.serial(4):
                with tir.block("C"):
                    vi = tir.axis.spatial(128, i)
                    vj = tir.axis.spatial(128, j_0 * 4 + ax0)
                    tir.reads(Y[vi, vj])
                    tir.writes(C[vi, vj])
                    C[vi, vj] = tir.max(Y[vi, vj], tir.float32(0))
```
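
As before, the transformed module can be built and checked against numpy (a minimal sketch with illustrative array names):

```py
import numpy as np

a_np = np.random.rand(128, 128).astype("float32")
b_np = np.random.rand(128, 128).astype("float32")

rt_lib = tvm.build(sch.mod, target="llvm")
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")
rt_lib["mm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_nd.numpy(), np.maximum(a_np @ b_np, 0), rtol=1e-5)
```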

## Another way to create and interact with TensorIR

### Intro to Tensor Expression

Tensor expression (te) is a domain-specific language that describes a sequence of computations via an expression-like API.

```py
from tvm import te

A = te.placeholder((128, 128), "float32", name="A")
B = te.placeholder((128, 128), "float32", name="B")
k = te.reduce_axis((0, 128), "k")
Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")
```

Here `te.compute` takes the signature `te.compute(output_shape, fcompute)`, and the fcompute function (the lambda in the example) describes how we want to compute the value of each element `Y[i, j]` for a given index.

In this particular case, we want to create a function with two input parameters (A, B) and one output parameter (C).

```py
te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
MyModuleFromTE = tvm.IRModule({"mm_relu": te_func})
IPython.display.Code(MyModuleFromTE.script(), language="python")
-------->
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def mm_relu(A: tir.Buffer[(128, 128), "float32"], B: tir.Buffer[(128, 128), "float32"], C: tir.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([128, 128], dtype="float32")
        for i0, i1, i2 in tir.grid(128, 128, 128):
            with tir.block("Y"):
                i, j, k = tir.axis.remap("SSR", [i0, i1, i2])
                tir.reads(A[i, k], B[k, j])
                tir.writes(Y[i, j])
                with tir.init():
                    Y[i, j] = tir.float32(0)
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
        for i0, i1 in tir.grid(128, 128):
            with tir.block("C"):
                i, j = tir.axis.remap("SS", [i0, i1])
                tir.reads(Y[i, j])
                tir.writes(C[i, j])
                C[i, j] = tir.max(Y[i, j], tir.float32(0))
```

The tensor expression API provides a helpful tool to generate TensorIR functions for a given higher-level input.
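
Since the result is an ordinary IRModule, it can be transformed with the same Schedule workflow as before (a minimal sketch):

```py
sch_te = tvm.tir.Schedule(MyModuleFromTE)
block_Y = sch_te.get_block("Y", func_name="mm_relu")
i, j, k = sch_te.get_loops(block_Y)
j0, j1 = sch_te.split(j, factors=[None, 4])
sch_te.reorder(j0, k, j1)
```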
