05-GPU MatMul and Compilers
This post was last updated on: August 22, 2025 pm
This is a lecture note for the course CSE 234 - Data Systems for ML - LE [A00].
- From UC San Diego
- Prof. Zhang Hao
- Winter, 2025
- Link: https://podcast.ucsd.edu/watch/wi25/cse234_a00/1
1 GPU MatMul V1
To calculate $C = AB$, where $A$, $B$, and $C$ are $N \times N$ matrices, we can write a kernel like this:
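A minimal CUDA sketch of such a naive kernel, assuming row-major float matrices and a 2D launch grid (the kernel and variable names here are illustrative, not the lecture's exact code):

```cuda
// One thread per output element C[y][x]; launch with a 2D grid covering N x N.
__global__ void matmul_v1(const float* A, const float* B, float* C, int N) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;  // column of C
    int y = blockIdx.y * blockDim.y + threadIdx.y;  // row of C
    if (x < N && y < N) {
        float acc = 0.0f;
        for (int k = 0; k < N; ++k) {
            // each thread reads one row of A and one column of B: 2N elements
            acc += A[y * N + k] * B[k * N + x];
        }
        C[y * N + x] = acc;
    }
}
```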
This is a pretty naive kernel: each thread picks one row and one column, then applies a dot product using the for loop.
But if we analyze the memory operations, we find that
- each thread reads $2N$ elements (from the for loop: $N$ from the row of $A$ and $N$ from the column of $B$),
- the number of threads is $N^2$,
- so the total memory access is $2N^3$.
This is highly memory intensive.
Note that the thread index is two-dimensional: (0, 0) and (1, 0) are two different threads. Their `threadIdx.y` values are the same, but the different `threadIdx.x` values indicate they are different threads.
2 GPU MatMul V1.5 Thread Tiling
From the last note, we know that tiling the for loop can reduce memory operations. A GPU has three levels of memory hierarchy: (1) HBM (global memory), (2) shared memory, and (3) registers. Let's first try to tile between HBM and registers.
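A CUDA-style sketch of this version, assuming each thread is responsible for a $V \times V$ tile of $C$ and that `N` is a compile-time constant divisible by `V` (the template parameters and names are my own illustration, not the lecture's exact code):

```cuda
template <int N, int V>
__global__ void matmul_v1_5(const float* A, const float* B, float* C) {
    // This thread owns rows [xbase*V, xbase*V+V) and columns [ybase*V, ybase*V+V) of C.
    int xbase = blockIdx.x * blockDim.x + threadIdx.x;
    int ybase = blockIdx.y * blockDim.y + threadIdx.y;

    float a[N];  // one full row of A (for large N this would spill; the sketch just mirrors the counting)
    for (int x = 0; x < V; ++x) {              // outer loop: V rows of A -> V*N reads per thread
        int row = xbase * V + x;
        for (int k = 0; k < N; ++k) a[k] = A[row * N + k];
        for (int y = 0; y < V; ++y) {          // inner loop: re-reads a column of B every time -> V*V*N reads
            int col = ybase * V + y;
            float acc = 0.0f;
            for (int k = 0; k < N; ++k)
                acc += a[k] * B[k * N + col];  // b[:] = B[:, ybase * V + y], read again and again
            C[row * N + col] = acc;
        }
    }
}
```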
Now we read one row and one column at a time, and each thread is responsible for $V$ rows and $V$ columns (a $V \times V$ tile of $C$). Let's calculate the memory operations again.
- Each thread reads $V \cdot N + V^2 \cdot N$ elements; the first term comes from the outer loop (reading $V$ rows of $A$), the second term comes from the inner loop (re-reading columns of $B$).
- Number of threads: $(N/V)^2$
- The total memory access is $(N/V)^2 \cdot (VN + V^2 N) = N^3/V + N^3$,
which is slightly better than V1's $2N^3$.
3 GPU MatMul V2 with Partial Sum
The last version contains a significant redundancy: we read `b[:] = B[:, ybase * V + y]` repeatedly. Every time we start a new outer-loop iteration, the inner loop reads the same columns of matrix `B` over and over again, which is unnecessary.
One improvement idea is to use partial sums. Suppose we have two matrices $A$ and $B$. If we split $A$ by columns into $[A_1 \; A_2]$ and $B$ by rows into $\begin{bmatrix} B_1 \\ B_2 \end{bmatrix}$, the multiplication becomes $AB = A_1 B_1 + A_2 B_2$.
The extreme case is to split $A$ into column vectors and $B$ into row vectors. The product of a column vector ($N \times 1$) and a row vector ($1 \times N$) is an $N \times N$ matrix, and the final result is the summation of all such matrices: $AB = \sum_{k=1}^{N} A[:, k] \, B[k, :]$.
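As a quick sanity check, here is a worked $2 \times 2$ example of this column-times-row (outer product) decomposition:

$$
\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}
\begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}
=
\begin{bmatrix} 1 \\ 3 \end{bmatrix}
\begin{bmatrix} 5 & 6 \end{bmatrix}
+
\begin{bmatrix} 2 \\ 4 \end{bmatrix}
\begin{bmatrix} 7 & 8 \end{bmatrix}
=
\begin{bmatrix} 5 & 6 \\ 15 & 18 \end{bmatrix}
+
\begin{bmatrix} 14 & 16 \\ 28 & 32 \end{bmatrix}
=
\begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix}
$$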
The kernel could be
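A CUDA-style sketch of this partial-sum (register tiling) version, assuming `N` is divisible by `V`; as before, the exact names and the template parameter are illustrative:

```cuda
template <int V>
__global__ void matmul_v2(const float* A, const float* B, float* C, int N) {
    // Each thread accumulates a V x V tile of C as a sum of rank-1 (outer product) updates.
    int ybase = blockIdx.y * blockDim.y + threadIdx.y;
    int xbase = blockIdx.x * blockDim.x + threadIdx.x;

    float c[V][V] = {0.0f};
    float a[V], b[V];
    for (int k = 0; k < N; ++k) {
        for (int i = 0; i < V; ++i) a[i] = A[(ybase * V + i) * N + k];  // V elements of column k of A
        for (int j = 0; j < V; ++j) b[j] = B[k * N + (xbase * V + j)];  // V elements of row k of B
        for (int y = 0; y < V; ++y)
            for (int x = 0; x < V; ++x)
                c[y][x] += a[y] * b[x];   // partial sum of the outer product a * b^T
    }
    // write this thread's V x V tile back
    for (int y = 0; y < V; ++y)
        for (int x = 0; x < V; ++x)
            C[(ybase * V + y) * N + (xbase * V + x)] = c[y][x];
}
```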
This code reads slices of column vectors of $A$ and row vectors of $B$ into registers, then applies the partial sum. Each thread is responsible for $V$ rows and $V$ columns of $C$.
Let's calculate the memory operations again:
- Each thread reads $2NV$ elements (the for loop has $N$ iterations, and each iteration reads $V$ elements of $A$ and $V$ elements of $B$).
- The total number of threads is $(N/V)^2$.
- The total memory access is $(N/V)^2 \cdot 2NV = 2N^3/V$.
This is much better, because we reduce the memory operations by a factor of $V$.
4 GPU MatMul V3 with Shared Memory
We still have one memory space not used, which is the shared memory (SRAM). We can write the loop like this:
- A block computes an $L \times L$ sub-matrix of $C$.
- A thread (within the block) computes a $V \times V$ sub-matrix.
- A block reads $S \times L$ elements from each of $A$ and $B$ into shared memory at a time, then its threads compute on them (see the sketch below).
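A CUDA sketch of this shared-memory version; the tile sizes `L`, `S`, `V` below are illustrative choices (not necessarily the lecture's), and it assumes `N` is divisible by `L` and `S`, with each block launched as `dim3(L / V, L / V)` threads:

```cuda
#define L 64   // each block computes an L x L tile of C
#define S 8    // the reduction (k) dimension is processed in strips of S
#define V 4    // each thread computes a V x V sub-tile

__global__ void matmul_v3(const float* A, const float* B, float* C, int N) {
    __shared__ float sA[L][S];   // L x S strip of A for this block
    __shared__ float sB[S][L];   // S x L strip of B for this block

    int yblock = blockIdx.y, xblock = blockIdx.x;
    float c[V][V] = {0.0f};
    float a[V], b[V];

    for (int ko = 0; ko < N; ko += S) {
        __syncthreads();
        // Load the strips from HBM into shared memory.
        // Naive load: every thread copies the whole strips (they all write the same
        // values) -- this is the redundancy fixed by cooperative fetching below.
        for (int i = 0; i < S; ++i)
            for (int j = 0; j < L; ++j) {
                sA[j][i] = A[(yblock * L + j) * N + (ko + i)];
                sB[i][j] = B[(ko + i) * N + (xblock * L + j)];
            }
        __syncthreads();

        // Compute on the shared strips, exactly like V2 but with S partial-sum steps per strip.
        for (int ki = 0; ki < S; ++ki) {
            for (int y = 0; y < V; ++y) a[y] = sA[threadIdx.y * V + y][ki];
            for (int x = 0; x < V; ++x) b[x] = sB[ki][threadIdx.x * V + x];
            for (int y = 0; y < V; ++y)
                for (int x = 0; x < V; ++x)
                    c[y][x] += a[y] * b[x];
        }
    }

    // Write this thread's V x V tile back to HBM.
    int row0 = yblock * L + threadIdx.y * V;
    int col0 = xblock * L + threadIdx.x * V;
    for (int y = 0; y < V; ++y)
        for (int x = 0; x < V; ++x)
            C[(row0 + y) * N + (col0 + x)] = c[y][x];
}
```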
In V2, each thread reads only one small vector at a time from global memory to compute the partial sum; this version first stages strips of multiple rows and columns in SRAM so that all threads in a block can reuse them, which reduces global memory IO further.
We can calculate the memory access again:
- Each block reads $2NL$ elements into shared memory in total ($N/S$ outer iterations, each loading two $S \times L$ strips).
- The number of blocks is $(N/L)^2$.
- The global memory access is therefore $(N/L)^2 \cdot 2NL = 2N^3/L$.
- Inside shared memory, each thread reads $2NV$ elements (because of the outer loop: $N/S$ iterations, each with $S$ steps reading $2V$ elements).
- The total number of threads is $(N/V)^2$.
- The total shared memory access is $(N/V)^2 \cdot 2NV = 2N^3/V$.
However, the way the shared-memory load is written above is problematic: the copy of the whole strip is executed by every thread, but inside one block we only need to read each element once. The solution is thread cooperative fetching.
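A sketch of cooperative fetching for the $S \times L$ strip of `B` from the kernel above (the strip of `A` is filled the same way); `nthreads`, `tid`, and `j` follow the discussion below:

```cuda
// All threads of the block jointly copy the S x L strip: thread `tid` loads
// elements tid, tid + nthreads, tid + 2 * nthreads, ... of the flattened strip.
int nthreads = blockDim.x * blockDim.y;               // threads per block
int tid = threadIdx.y * blockDim.x + threadIdx.x;     // flattened 2D thread index

for (int j = 0; j < S * L / nthreads; ++j) {
    int idx = j * nthreads + tid;  // which element of the strip this thread loads now
    int y = idx / L;               // row inside the strip
    int x = idx % L;               // column inside the strip
    sB[y][x] = B[(ko + y) * N + (xblock * L + x)];
}
```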
The key part is the index calculation: `(j * nthreads + tid) / L` gives the row index and `(j * nthreads + tid) % L` gives the column index. Here `L` is the length of one row of the strip, so moving forward by `L` elements advances to the next row.
Let's run through an example. Assume we have 4 threads (`nthreads = 4`) arranged as a 2 x 2 block, with (`threadIdx.x`, `threadIdx.y`) indices (0, 0), (1, 0), (0, 1), (1, 1). Then `tid = threadIdx.y * blockDim.x + threadIdx.x` is
- 0 * 2 + 0 = 0 for (0, 0)
- 0 * 2 + 1 = 1 for (1, 0)
- 1 * 2 + 0 = 2 for (0, 1)
- 1 * 2 + 1 = 3 for (1, 1)
Therefore, we have flattened the 2D thread indices into a single `tid`. Next, let's see how the strip is read into shared memory. Assume the strip has $S \cdot L = 8$ elements, so $S \cdot L / \text{nthreads} = 2$ and there are two iterations: in iteration `j = 0` the four threads read elements 0-3 of the strip, and in iteration `j = 1` they read elements 4-7. So basically, we use two iterations, and in each iteration the 4 threads independently read a different element.
More GPU optimizations
- Coalesced (contiguous) global memory reads
- Shared memory bank conflicts
- Pipelining
- Tensor cores
- Lower precision
5 ML Compiler
Although GPU MatMul V3 is better than the previous versions, we still need to tune a lot of parameters, like the sizes of `L`, `S`, and `V`, and the optimal combination can differ across devices. This means we would have to manually tune these parameters for every GPU, which is not elegant.
One solution is to build a compiler. ML compilation automatically generates optimized configurations and code given the user's code and the target hardware.
Some well-known compilers include XLA (Google), TVM (academia), torch.compile, and Modular (a startup).
However, we have not yet achieved a truly good ML compiler: the famous FlashAttention operator was designed by humans, not discovered by a compiler.
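As a small illustration of the user-facing side (a hedged sketch; the function, shapes, and names are placeholders): with `torch.compile`, the compiler, rather than the user, picks and generates the device-specific kernels.

```python
import torch

def mlp(x, w1, w2):
    # a tiny computation whose kernels the compiler will generate and fuse
    return torch.relu(x @ w1) @ w2

compiled_mlp = torch.compile(mlp)  # let the compiler choose tile sizes etc. for this GPU

x  = torch.randn(256, 512, device="cuda")
w1 = torch.randn(512, 1024, device="cuda")
w2 = torch.randn(1024, 256, device="cuda")
out = compiled_mlp(x, w1, w2)      # first call compiles; later calls reuse the compiled code
```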
6 Intro on Triton
Writing CUDA allows you to squeeze out the last bits of performance, and you can use whatever data structure you want. However, this requires deep expertise, achieving the best performance is very time-consuming, and a CUDA codebase is complex and hard to maintain.
Triton aims to build a language that sits between a device-specific DSL (CUDA) and a compiler, so developers can prototype ideas quickly. We lose some flexibility, such as in-operator control flow and custom data structures, but it is friendlier than CUDA.
Triton is Python native: users define tensors in SRAM and modify them using torch-like primitives. There are two facts about how Triton maps onto CUDA blocks.
- A Triton kernel program is mapped to a single block of threads (scheduled on one SM).
- Users are responsible for mapping the computation onto multiple blocks (the launch grid).
An example of element-wise addition written in Triton is:
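A sketch following the standard Triton vector-add pattern (the kernel name, `BLOCK_SIZE`, and the wrapper are illustrative, not necessarily the lecture's exact code):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                       # which block of data this program handles
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                       # guard against out-of-bounds accesses
    x = tl.load(x_ptr + offsets, mask=mask)           # pointer arithmetic instead of indexing
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    # the user maps the problem onto multiple blocks via the launch grid
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out
```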
As you can see, Triton accesses data with pointers, and it automatically maps the program onto the threads within a block.