05-GPU MatMul and Compilers

Last updated: August 22, 2025 PM

This is a lecture note for the course CSE 234 - Data Systems for ML - LE [A00].

1 GPU MatMul V1

To calculate $C = A \times B$, we can write a kernel like this:

int N = 1024;
dim3 threadsPerBlock(32, 32, 1);
dim3 numBlocks(N/32, N/32, 1);

mm<<<numBlocks, threadsPerBlock>>>(A, B, C);

__global__ void mm(float A[N][N], float B[N][N], float C[N][N]){
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;

    float result = 0;
    for (int k = 0; k < N; ++k){
        // standard dot product
        result += A[x][k] * B[k][y];
    }
    C[x][y] = result;
}

This is a pretty naive kernel: each thread picks one row of A and one column of B, then computes their dot product with the for loop.

But if we analyze the memory operations, we find that

  • Each thread reads $2N$ elements (from the for loop)
  • The number of threads is $N^2$
  • The total memory access is $N^2 \times 2N = 2N^3$

This is highly memory intensive.
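
As a quick back-of-the-envelope check (assuming 4-byte floats): the kernel performs $2N^3$ floating-point operations (one multiply and one add per term) while also issuing $2N^3$ global loads, so

$$\text{arithmetic intensity} = \frac{2N^3\ \text{FLOPs}}{2N^3 \times 4\ \text{bytes}} = 0.25\ \text{FLOP/byte},$$

which is far too low to keep the compute units busy; the kernel is bound by HBM bandwidth.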

Note that the threads are two-dimensional: (0, 0) and (1, 0) are two different threads; even though their threadIdx.y values are the same, the different threadIdx.x values make them distinct threads.

2 GPU MatMul V1.5 Thread Tiling

From the last note we know that tiling the for loop can reduce memory operations. A GPU has three levels of memory hierarchy: (1) HBM, (2) shared memory, and (3) registers. Let's first try to tile between (1) and (3).

__global__ void mm(float A[N][N], float B[N][N], float C[N][N]){
    int ybase = blockIdx.y * blockDim.y + threadIdx.y;
    int xbase = blockIdx.x * blockDim.x + threadIdx.x;

    // calculate submatrix result with shape V x V
    float c[V][V] = {0};
    float a[N], b[N];
    for (int x = 0; x < V; ++x){
        // each thread processes V sub rows
        // (from xbase * V + 0 to xbase * V + V-1)
        // read one row from A
        a[:] = A[xbase * V + x, :];
        for (int y = 0; y < V; ++y){
            // read one column from B
            b[:] = B[:, ybase * V + y];
            for (int k = 0; k < N; ++k){
                // dot product
                c[x][y] += a[k] * b[k];
            }
        }
    }
    C[xbase * V:xbase * V + V, ybase * V:ybase * V + V] = c;
}

Now we read one full row or column at a time, and each thread is responsible for V rows and V columns. Let's calculate the memory operations again.

  • Each thread reads $NV + NV^2$ elements; the first term comes from the outer loop, the second term comes from the inner loop.
  • Number of threads: $(N/V)(N/V) = N^2/V^2$
  • The total memory access is $(N^2/V^2)(NV + NV^2) = N^3/V + N^3$

which is slightly better than V1.
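
To quantify "slightly better", compare with V1's $2N^3$:

$$\frac{N^3/V + N^3}{2N^3} = \frac{1}{2} + \frac{1}{2V},$$

so even for a large $V$ the traffic drops by at most about half; the dominant $N^3$ term, caused by re-reading columns of B in the inner loop, is still there.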

3 GPU MatMul V2 with Partial Sum

The last version still has significant redundant reads: we load b[:] = B[:, ybase * V + y] repeatedly. Every time the outer loop advances, the inner loop reads the columns of matrix B all over again, which is unnecessary.

One improvement idea is to use partial sums. Suppose we have two matrices $X$ and $Y$; we can split $X$ into $[X_1, X_2]$ and $Y$ into $\begin{bmatrix}Y_1\\Y_2\end{bmatrix}$, so the multiplication becomes

$$X\times Y = [X_1, X_2]\begin{bmatrix}Y_1\\Y_2\end{bmatrix} = X_1Y_1 + X_2Y_2$$

The extreme case is to split $X$ into column vectors and $Y$ into row vectors. The product of a column vector ($N\times 1$) and a row vector ($1\times N$) is an $N\times N$ matrix, and the final result is the sum of all such matrices.

$$X\times Y = \sum_{i=1}^{N} \mathbf{x}_i\mathbf{y}_i$$
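
As a tiny sanity check, take $N = 2$, with $\mathbf{x}_i$ the columns of $X$ and $\mathbf{y}_i$ the rows of $Y$:

$$
\begin{bmatrix} 1 & 2\\ 3 & 4\end{bmatrix}
\begin{bmatrix} 5 & 6\\ 7 & 8\end{bmatrix}
=
\begin{bmatrix}1\\3\end{bmatrix}\begin{bmatrix}5 & 6\end{bmatrix}
+
\begin{bmatrix}2\\4\end{bmatrix}\begin{bmatrix}7 & 8\end{bmatrix}
=
\begin{bmatrix}5 & 6\\15 & 18\end{bmatrix}
+
\begin{bmatrix}14 & 16\\28 & 32\end{bmatrix}
=
\begin{bmatrix}19 & 22\\43 & 50\end{bmatrix}.
$$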

The kernel could be

__global__ void mm(float A[N][N], float B[N][N], float C[N][N]){
    int ybase = blockIdx.y * blockDim.y + threadIdx.y;
    int xbase = blockIdx.x * blockDim.x + threadIdx.x;

    float c[V][V] = {0};
    float a[V], b[V];
    for (int k = 0; k < N; ++k){
        // read one column and V rows from A (column vector)
        a[:] = A[xbase*V:xbase*V+V, k];
        // read one row and V columns from B (row vector)
        b[:] = B[k, ybase*V:ybase*V+V];
        // loop over the column vector and row vector,
        // accumulating the partial sum
        for (int y = 0; y < V; ++y){
            for (int x = 0; x < V; ++x){
                c[x][y] += a[x] * b[y];
            }
        }
    }
    C[xbase * V:xbase*V + V, ybase*V:ybase*V + V] = c;
}

This code reads column vectors and row vectors into registers and then applies the partial sum. Each thread is responsible for $V$ rows and $V$ columns.
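
For reference, here is a minimal compilable CUDA sketch of this V2 kernel, replacing the slice notation with explicit loops. It assumes flat row-major arrays, N divisible by V, and an arbitrary illustrative tile size V = 4:

constexpr int N = 1024;  // matrix size
constexpr int V = 4;     // per-thread tile size (illustrative choice)

__global__ void mm_v2(const float *A, const float *B, float *C){
    // this thread owns the V x V output tile starting at
    // row xbase * V, column ybase * V
    int xbase = blockIdx.x * blockDim.x + threadIdx.x;
    int ybase = blockIdx.y * blockDim.y + threadIdx.y;

    float c[V][V] = {0.0f};
    float a[V], b[V];

    for (int k = 0; k < N; ++k){
        // V x 1 slice of column k of A and 1 x V slice of row k of B
        for (int x = 0; x < V; ++x) a[x] = A[(xbase * V + x) * N + k];
        for (int y = 0; y < V; ++y) b[y] = B[k * N + (ybase * V + y)];
        // rank-1 update (partial sum)
        for (int x = 0; x < V; ++x)
            for (int y = 0; y < V; ++y)
                c[x][y] += a[x] * b[y];
    }
    // write back the V x V tile
    for (int x = 0; x < V; ++x)
        for (int y = 0; y < V; ++y)
            C[(xbase * V + x) * N + (ybase * V + y)] = c[x][y];
}

A matching launch would use dim3 threads(T, T) and dim3 blocks(N / (V * T), N / (V * T)), so that blockDim.x * gridDim.x equals N / V.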

Let’s calculate the memory operations again

  • Each thread reads $2NV$ elements (the loop runs $N$ times, and each iteration reads $V$ elements from A and $V$ from B)
  • The total number of threads is $(N/V)(N/V) = N^2/V^2$
  • The total memory access is $(N^2/V^2)\times 2NV = 2N^3/V$

This is much better because we reduce the memory operations by a factor of $V$.

4 GPU MatMul V3 with Shared Memory

We still have one memory space not used, which is the shared memory (SRAM). We can write the loop like this:

  • A block computes an $L\times L$ sub-matrix
  • A thread computes a $V\times V$ ($V<L$) sub-matrix
  • A block reads $L\times S$ elements into shared memory at a time, then the threads compute on them.
__global__ void mm(float A[N][N], float B[N][N], float C[N][N]){
    __shared__ float sA[L][S], sB[S][L];
    float c[V][V] = {0};
    float a[V], b[V];
    int yblock = blockIdx.y;
    int xblock = blockIdx.x;

    for (int ko = 0; ko < N; ko += S){
        // wait in case some threads finish too early
        __syncthreads();
        // each block reads S x L elements,
        // loop on S and parallel on L
        sA[:, :] = A[yblock*L:yblock*L+L, ko:ko+S];
        sB[:, :] = B[ko:ko+S, xblock*L:xblock*L+L];

        // wait until all elements are read
        __syncthreads();
        for (int ki = 0; ki < S; ++ki){
            // loop on S and parallel on L
            // get one small column vector with size V x 1
            a[:] = sA[threadIdx.y * V:threadIdx.y * V + V, ki];
            // get one small row vector with size 1 x V
            b[:] = sB[ki, threadIdx.x * V:threadIdx.x * V + V];
            // now loop over a and b to accumulate the partial sum
            for (int y = 0; y < V; ++y){
                for (int x = 0; x < V; ++x){
                    // small sub-matrix V x V
                    c[y][x] += a[y] * b[x];
                }
            }
        }
    }
    int ybase = blockIdx.y * blockDim.y + threadIdx.y;
    int xbase = blockIdx.x * blockDim.x + threadIdx.x;
    /*
    One thread computes a small V x V matrix. As blockIdx changes,
    we move by blockDim threads, so one block generates blockDim^2
    V x V matrices.

    Since the threads operate inside an L x L output tile, we must
    allocate L / V threads per dimension, i.e. blockDim = L / V, so
    the V x V matrices finally compose into a
    V * blockDim x V * blockDim = L x L matrix.
    */
    C[ybase * V:ybase * V + V, xbase * V:xbase * V + V] = c;
}

In V2 we read only one column of A and one row of B at a time to compute a partial sum; this version reads multiple rows and columns into SRAM at once, so we can cut global memory IO even further.

We can calculate the memory access again

  • Each block reads $2LN$ elements into shared memory
  • The number of blocks is $N^2/L^2$
  • The global memory access is therefore $2N^3/L$
  • Inside shared memory, each thread reads $2VS\times (N/S) = 2VN$ elements (because of the outer loop)
  • The total number of threads is $(L^2/V^2)\times(N^2/L^2) = N^2/V^2$
  • The total shared memory access is $2N^3/V$ (see the worked numbers below)
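
To get a feel for the numbers, plug in assumed (illustrative, untuned) sizes $N = 1024$, $L = 64$, $V = 8$:

$$\text{V1 HBM: } 2N^3 \approx 2.1\times 10^9, \qquad \text{V2 HBM: } \frac{2N^3}{V} \approx 2.7\times 10^8, \qquad \text{V3 HBM: } \frac{2N^3}{L} \approx 3.4\times 10^7$$

element accesses, while V3's remaining $2N^3/V \approx 2.7\times 10^8$ accesses now hit the much faster shared memory instead of HBM.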

However, the code

sA[:, :] = A[yblock*L:yblock*L+L, ko:ko+S];
sB[:, :] = B[ko:ko+S, xblock*L:xblock*L+L];

is problematic: this snippet would be executed by every thread, but within one block we only need to read each tile once. The solution is cooperative fetching among the threads.

sB[:, :] = B[ko:ko+S, xblock*L:xblock*L+L];
// converts to
int nthreads = blockDim.y * blockDim.x;
// flatten 2D thread index
int tid = threadIdx.y * blockDim.x + threadIdx.x;

for (int j = 0; j < L*S/nthreads; ++j){
    int y = (j*nthreads + tid) / L;
    int x = (j*nthreads + tid) % L;
    sB[y, x] = B[ko + y, xblock*L + x];
}

The indexing calculation is

1
2
y = (j * nthreads + tid) / L;
x = (j * nthreads + tid) % L;

As you can see, the division by L gives the row index and the modulo by L gives the column index. Here, L is the row stride: we move to the next row every L elements.
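
As an aside, a slightly more idiomatic way to write the same loop in CUDA is to stride the flattened thread id directly; this sketch assumes sB is declared as __shared__ float sB[S][L] and B is a flat row-major array, and it also works when L*S is not divisible by nthreads:

int nthreads = blockDim.x * blockDim.y;
int tid = threadIdx.y * blockDim.x + threadIdx.x;  // flattened thread index
for (int j = tid; j < S * L; j += nthreads){
    int y = j / L;   // row inside the S x L tile
    int x = j % L;   // column inside the tile
    sB[y][x] = B[(ko + y) * N + (xblock * L + x)];
}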

Let's run through an example. Assume we have 4 threads (nthreads = 4) in a 2 x 2 block, with (threadIdx.x, threadIdx.y) indices (0, 0), (1, 0), (0, 1), (1, 1); then tid is

  • 0 * 2 + 0 = 0 for (0, 0)
  • 0 * 2 + 1 = 1 for (1, 0)
  • 1 * 2 + 0 = 2 for (0, 1)
  • 1 * 2 + 1 = 3 for (1, 1)

Therefore, we have flattened the 2D thread indices. Next, let's see how the elements are read into shared memory. Assume $L = 4, S = 2$; then L*S/nthreads = 8/4 = 2, so there are two iterations.

j = 0

// tid=0
y = (0 + 0) / 4 = 0
x = (0 + 0) % 4 = 0
sB[0, 0] = B[ko, xblock*L]

// tid=1
y = (0 + 1) / 4 = 0
x = (0 + 1) % 4 = 1
sB[0, 1] = B[ko, xblock*L + 1]

// tid=2
y = (0 + 2) / 4 = 0
x = (0 + 2) % 4 = 2
sB[0, 2] = B[ko, xblock*L + 2]

// tid=3
y = (0 + 3) / 4 = 0
x = (0 + 3) % 4 = 3
sB[0, 3] = B[ko, xblock*L + 3]

j = 1

// tid=0
y = (1*4 + 0) / 4 = 1
x = (1*4 + 0) % 4 = 0
sB[1, 0] = B[ko + 1, xblock*L + 0]

// tid=1
y = (1*4 + 1) / 4 = 1
x = (1*4 + 1) % 4 = 1
sB[1, 1] = B[ko + 1, xblock*L + 1]

// tid=2
y = (1*4 + 2) / 4 = 1
x = (1*4 + 2) % 4 = 2
sB[1, 2] = B[ko + 1, xblock*L + 2]

// tid=3
y = (1*4 + 3) / 4 = 1
x = (1*4 + 3) % 4 = 3
sB[1, 3] = B[ko + 1, xblock*L + 3]

So basically, we use two iterations, and in each iteration the 4 threads independently read different elements, filling the whole S x L tile cooperatively.
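
Putting the shared-memory tiling, register tiling, and cooperative fetching together, a minimal compilable CUDA sketch of V3 could look like the following. The tile sizes L = 64, S = 8, V = 4 are illustrative assumptions, and the arrays are flat row-major with N divisible by L and S:

constexpr int N = 1024;   // matrix size
constexpr int L = 64;     // block tile (L x L output per block)
constexpr int S = 8;      // k-tile loaded into shared memory per step
constexpr int V = 4;      // thread tile (V x V output per thread)

__global__ void mm_v3(const float *A, const float *B, float *C){
    __shared__ float sA[L][S], sB[S][L];
    float c[V][V] = {0.0f};
    float a[V], b[V];

    int yblock = blockIdx.y, xblock = blockIdx.x;
    int nthreads = blockDim.x * blockDim.y;            // = (L/V)^2
    int tid = threadIdx.y * blockDim.x + threadIdx.x;  // flattened thread id

    for (int ko = 0; ko < N; ko += S){
        __syncthreads();  // the previous tile must no longer be in use

        // cooperative fetch: all threads share the loading work
        for (int j = tid; j < L * S; j += nthreads){
            int y = j / S, x = j % S;
            sA[y][x] = A[(yblock * L + y) * N + (ko + x)];
        }
        for (int j = tid; j < S * L; j += nthreads){
            int y = j / L, x = j % L;
            sB[y][x] = B[(ko + y) * N + (xblock * L + x)];
        }
        __syncthreads();  // wait until both tiles are fully loaded

        for (int ki = 0; ki < S; ++ki){
            // V x 1 column of sA and 1 x V row of sB for this thread
            for (int y = 0; y < V; ++y) a[y] = sA[threadIdx.y * V + y][ki];
            for (int x = 0; x < V; ++x) b[x] = sB[ki][threadIdx.x * V + x];
            // rank-1 update of this thread's V x V tile
            for (int y = 0; y < V; ++y)
                for (int x = 0; x < V; ++x)
                    c[y][x] += a[y] * b[x];
        }
    }

    // write this thread's V x V tile inside the block's L x L tile
    int ybase = blockIdx.y * blockDim.y + threadIdx.y;
    int xbase = blockIdx.x * blockDim.x + threadIdx.x;
    for (int y = 0; y < V; ++y)
        for (int x = 0; x < V; ++x)
            C[(ybase * V + y) * N + (xbase * V + x)] = c[y][x];
}

With these sizes it would be launched with dim3 threads(L / V, L / V) and dim3 blocks(N / L, N / L), i.e. 16 x 16 = 256 threads per block.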

More GPU optimizations

  • Coalesced (contiguous) global memory reads
  • Avoiding shared memory bank conflicts
  • Pipelining
  • Tensor cores
  • Lower precision

5 ML Compiler

Although GPU MatMul V3 is better than the previous versions, we still need to tune many parameters, such as the sizes of L, S, and V, and the optimal combination can differ across devices. That means we would have to manually tune these parameters for every GPU, which is not elegant.

One solution is to build a compiler. ML compilation automatically generates optimized configurations and code from the user's code and the target hardware.

There are some famous compilers, like XLA (Google), TVM (academia), torch.compile (PyTorch), and Modular (a startup).

However, we have not yet achieved a truly good ML compiler; for example, the famous FlashAttention operator was designed by humans rather than discovered by a compiler.

6 Intro on Triton

Writing CUDA allows you to squeeze out the last bits of performance, and you can use whatever data structures you want. However, this requires deep expertise, achieving the best performance is very time-consuming, and a CUDA codebase is complex and hard to maintain.

Triton aims to be a language sitting between a device-specific DSL (CUDA) and a compiler, so developers can prototype ideas quickly. We lose some flexibility, such as in-operator control flow and custom data structures, but it is much friendlier than CUDA.

Triton is Python-native: users define tensors in SRAM and modify them using torch-like primitives. There are two facts about how Triton maps onto CUDA blocks.

  • A Triton kernel instance is mapped to a single block of threads (running on one SM).
  • Users are responsible for mapping the problem across multiple blocks.

An example of element-wise addition written in Triton is

import torch
import triton
import triton.language as tl

@triton.jit
def _add(z_ptr, x_ptr, y_ptr, N: tl.constexpr):  # tl.arange needs a compile-time constant
    offsets = tl.arange(0, N)  # on SRAM
    # typical ptr + offsets arithmetic to get all elements
    x_ptrs = x_ptr + offsets
    y_ptrs = y_ptr + offsets
    z_ptrs = z_ptr + offsets
    # load actual elements
    # each element lives at a different address, cooperative fetching
    x = tl.load(x_ptrs)
    y = tl.load(y_ptrs)
    z = x + y
    tl.store(z_ptrs, z)

N = 1024
x = torch.randn(N, device="cuda")
y = torch.randn(N, device="cuda")
z = torch.randn(N, device="cuda")
grid = (1, )
_add[grid](z, x, y, N)

As you can see, Triton accesses data through pointers, and it automatically maps the program onto the threads within a block.

