03-Optimization on Operator and Matrix Multiplication

Last updated: June 28, 2025 pm

These are lecture notes for the course CSE 234 - Data Systems for ML - LE [A00].

1 The Benefit of Stride Layout

From the last lecture we know that PyTorch uses Tensor.stride to represent the memory layout of a tensor. This layout has many benefits; most importantly, it enables "zero-copy" operators.

The first such operator is Tensor.view. Suppose we have a contiguous tensor with shape (3, 2). If we call view(2, 3), we only need to change the stride from (2, 1) to (3, 1) (along with the shape); we don't need to create a new tensor or copy any elements.
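To make the metadata-only nature of view concrete, here is a minimal pure-Python sketch (no PyTorch needed) that models a tensor as a flat list plus a shape and a stride. The helper names row_major_strides and to_nested are hypothetical, and the model assumes a contiguous, row-major input with offset 0.

```python
# Toy model of a strided tensor: flat storage + (shape, stride, offset).
# Illustrative only; this is not how PyTorch is implemented internally.

def row_major_strides(shape):
    """Contiguous (row-major) strides: the last dimension moves by 1 element."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


def to_nested(storage, shape, stride, offset=0):
    """Read a 2-D strided view element by element into nested lists."""
    rows, cols = shape
    return [[storage[offset + i * stride[0] + j * stride[1]] for j in range(cols)]
            for i in range(rows)]


storage = list(range(6))                                # one flat buffer
x_shape, x_stride = (3, 2), row_major_strides((3, 2))   # stride (2, 1)
v_shape, v_stride = (2, 3), row_major_strides((2, 3))   # "view(2, 3)": stride (3, 1)

print(x_stride, v_stride)                     # (2, 1) (3, 1)
print(to_nested(storage, x_shape, x_stride))  # [[0, 1], [2, 3], [4, 5]]
print(to_nested(storage, v_shape, v_stride))  # [[0, 1, 2], [3, 4, 5]]
```

Both views read the very same storage list; only the (shape, stride) pair differs.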

Another operator is slicing. For example, given a tensor a with shape (4, 5) and strides (5, 1), to take a[0:3, 1:3] we only move the offset by +1 and reduce the shape to (3, 2); the strides stay the same. The memory access formula (row major) is

$$a[i][j] = a'[\text{offset}+i\cdot \mathrm{stride}(0)+j\cdot \mathrm{stride}(1)]$$

where $a'$ is the 1D layout in memory. Setting $\text{offset}=1$ makes every row skip its first element (the 1: in 1:3), and constraining $i<3, j<2$ keeps every access inside the slice. In summary, adjusting the offset and the index bounds enables zero-copy slicing.
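A sketch of the slicing rule in the same toy model, assuming step-1 slices of a 2-D row-major tensor; slice_meta and to_nested are hypothetical helpers, not PyTorch functions.

```python
def slice_meta(stride, offset, starts, stops):
    """Metadata of a[starts[0]:stops[0], starts[1]:stops[1]] with step 1.

    The strides are unchanged; only the offset and the shape change."""
    new_offset = offset + sum(s * st for s, st in zip(starts, stride))
    new_shape = tuple(e - s for s, e in zip(starts, stops))
    return new_shape, stride, new_offset


def to_nested(storage, shape, stride, offset=0):
    """Read a 2-D strided view into nested lists."""
    rows, cols = shape
    return [[storage[offset + i * stride[0] + j * stride[1]] for j in range(cols)]
            for i in range(rows)]


storage = list(range(20))      # a: shape (4, 5), stride (5, 1)
shape, stride, offset = slice_meta((5, 1), 0, starts=(0, 1), stops=(3, 3))
print(shape, stride, offset)                       # (3, 2) (5, 1) 1
print(to_nested(storage, shape, stride, offset))   # [[1, 2], [6, 7], [11, 12]]
```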

Transposition (Tensor.transpose, or more generally Tensor.permute) can also be implemented with strides alone, as in this example:

>>> import torch
>>> a = torch.randn(1,2,3,4)
>>> a.stride()
(24, 12, 4, 1)
>>> a = a.permute(1,2,3,0)
>>> a.stride()
(12, 4, 1, 24)

The operation above moves the first dimension to the last position. PyTorch only permutes the strides (and the shape, i.e., the index bounds); no data is moved.
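The stride change for permute is just a reordering of the shape and stride tuples. This sketch (hypothetical helpers permute_meta and offset_of, offset assumed to be 0) reproduces the strides printed above and checks that every element keeps its storage location.

```python
from itertools import product


def permute_meta(shape, stride, dims):
    """Metadata of permute(*dims): output dim d is input dim dims[d]."""
    return tuple(shape[d] for d in dims), tuple(stride[d] for d in dims)


def offset_of(index, stride):
    """Storage offset of a multi-index (offset 0 assumed)."""
    return sum(i * s for i, s in zip(index, stride))


shape, stride = (1, 2, 3, 4), (24, 12, 4, 1)      # like torch.randn(1, 2, 3, 4)
dims = (1, 2, 3, 0)
new_shape, new_stride = permute_meta(shape, stride, dims)
print(new_shape, new_stride)                      # (2, 3, 4, 1) (12, 4, 1, 24)

# No data moves: the permuted element at (i1, i2, i3, i0) sits where (i0, i1, i2, i3) was.
for idx in product(*(range(d) for d in shape)):
    permuted_idx = tuple(idx[d] for d in dims)
    assert offset_of(permuted_idx, new_stride) == offset_of(idx, stride)
```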

Another such operator is broadcasting, e.g.

import torch

a = torch.randn(4, 3)
b = torch.randn(3)
a + b

Conceptually, b gets an additional leading dimension and its elements are repeated 4 times, so that the shape of b matches that of a. Nothing is actually copied, though. If we call stride() on b, we see

>>> b.stride()
(1,)

To broadcast, we just prepend a dimension with stride 0 to b, so the stride of b becomes (0, 1). When accessing b, we compute

$$b[i][j]=b'[i\cdot 0+j\cdot 1] = b'[j]$$

So whatever the row index i is, we always go back to the beginning of the original storage. This emulates copying the tensor multiple times along the row dimension without using any extra memory.
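A toy version of broadcasting with a zero stride, using plain Python lists instead of tensors (view_2d is a hypothetical helper). In PyTorch itself, b.expand(4, 3).stride() should report (0, 1) for a contiguous 1-D b, but the sketch below only simulates the indexing.

```python
def view_2d(storage, shape, stride, offset=0):
    """Read a 2-D strided view: element [i][j] = storage[offset + i*stride[0] + j*stride[1]]."""
    rows, cols = shape
    return [[storage[offset + i * stride[0] + j * stride[1]] for j in range(cols)]
            for i in range(rows)]


a_storage = list(range(12))     # a: shape (4, 3), stride (3, 1)
b_storage = [10, 20, 30]        # b: shape (3,),   stride (1,)

a_view = view_2d(a_storage, (4, 3), (3, 1))
b_view = view_2d(b_storage, (4, 3), (0, 1))    # broadcast: new leading dim with stride 0

result = [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a_view, b_view)]
print(b_view)   # [[10, 20, 30], [10, 20, 30], [10, 20, 30], [10, 20, 30]]
print(result)   # [[10, 21, 32], [13, 24, 35], [16, 27, 38], [19, 30, 41]]
```

Note that b_storage still holds only 3 numbers; the zero stride does all the "copying".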

Home Exercise: Swapping tiles
Transform the matrix

$$A=\begin{bmatrix}0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ 8 & 9 & 10 & 11 \\ 12 & 13 & 14 & 15\end{bmatrix} \rightarrow A'=\begin{bmatrix}0 & 1 & 8 & 9 \\ 4 & 5 & 12 & 13 \\ 2 & 3 & 10 & 11 \\ 6 & 7 & 14 & 15\end{bmatrix}$$

by modifying its strides.

Solution (not sure whether this satisfies the professor's requirements):
Let $i$ denote the row index and $j$ the column index.

Observing the matrix, within each row the elements are contiguous while $j<2$, and from $j \ge 2$ on they are shifted 6 elements further ahead (e.g., row 0 reads 0, 1 and then 8, 9). This holds for every row, so we can write the column access as

$$j + (j//2)\cdot 6$$

For row access, the first two rows follow the original pattern (each row starts 4 elements after the previous one). When $i \ge 2$, however, we jump back to the first two rows and start from their third element, so we can write the row access as

$$4i(1-i//2) + (i//2)(4(i-2)+2)$$

Putting them together:

$$A'[i][j] = A_m[\underbrace{4i(1-i//2)}_{i<2\text{: normal access}} + \underbrace{(i//2)}_{\text{if }i\ge 2}\underbrace{(4(i-2)+2)}_{\text{jump back and offset}}+j + \underbrace{(j//2)\cdot 6}_{\text{skip 6 elements if }j\ge 2}]$$

where $A_m$ is the array stored in memory. Let's verify with code:

>>> import torch
>>> a = torch.arange(16)
>>> a
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
>>> def get(i, j):
...     return a[4*i*(1-i//2)+(i//2)*((i-2)*4+2)+j+(j//2)*6]
...
>>> b = torch.zeros(4, 4)
>>> for i in range(4):
...     for j in range(4):
...         b[i][j] = get(i, j)
...
>>> b
tensor([[ 0., 1., 8., 9.],
        [ 4., 5., 12., 13.],
        [ 2., 3., 10., 11.],
        [ 6., 7., 14., 15.]])
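The index map above uses integer division, so strictly speaking it is not a stride map on (i, j). One possible pure-stride alternative (our own suggestion, not from the lecture) is to split each axis into (tile index, position within the tile), i.e., to view A as a 2×2×2×2 tensor with axes (I, i2, J, j2), where row = 2I + i2 and col = 2J + j2. Swapping the off-diagonal tiles means output tile (I, J) reads input tile (J, I), so the storage offset is 4(2J + i2) + 2I + j2, giving strides (2, 4, 8, 1). In PyTorch this would presumably correspond to torch.as_strided(a, (2, 2, 2, 2), (2, 4, 8, 1)); the sketch below only checks the arithmetic in plain Python.

```python
from itertools import product

storage = list(range(16))        # A stored row-major: A[r][c] = storage[4*r + c]

# 4-D view with axes (I, i2, J, j2): row = 2*I + i2, col = 2*J + j2.
# Output tile (I, J) reads input tile (J, I): offset = 2*I + 4*i2 + 8*J + j2.
shape4, stride4 = (2, 2, 2, 2), (2, 4, 8, 1)

A_prime = [[None] * 4 for _ in range(4)]
for I, i2, J, j2 in product(*(range(d) for d in shape4)):
    offset = I * stride4[0] + i2 * stride4[1] + J * stride4[2] + j2 * stride4[3]
    A_prime[2 * I + i2][2 * J + j2] = storage[offset]

for row in A_prime:
    print(row)
# [0, 1, 8, 9]
# [4, 5, 12, 13]
# [2, 3, 10, 11]
# [6, 7, 14, 15]
```

Flattening this 4-D view back into a 4×4 matrix cannot itself be expressed with strides (the (I, i2) pair does not merge into a single stride), so that last step would need a copy.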

However, strides are not a perfect solution. After some operations (e.g., views created by transposing or slicing), a tensor's memory is no longer contiguous, yet many vectorized operators require contiguous storage. If you run into such errors, call Tensor.contiguous, which copies the data into a new, contiguous memory layout.
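A sketch of what "contiguous" means in the toy model: the strides must equal the row-major strides for the shape, and a contiguous() call conceptually copies the elements into a fresh row-major buffer. The helpers are hypothetical, and is_contiguous is simplified (it ignores size-1 dimensions and assumes row-major order).

```python
def row_major_strides(shape):
    """Contiguous (row-major) strides for a shape."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, stride):
    """Simplified check: strides equal the row-major strides (size-1 dims ignored)."""
    expected = row_major_strides(shape)
    return all(s == e for n, s, e in zip(shape, stride, expected) if n != 1)


def contiguous_2d(storage, shape, stride, offset=0):
    """Copy a 2-D strided view into a fresh row-major buffer."""
    rows, cols = shape
    new_storage = [storage[offset + i * stride[0] + j * stride[1]]
                   for i in range(rows) for j in range(cols)]
    return new_storage, shape, row_major_strides(shape)


storage = list(range(6))                  # x: shape (2, 3), stride (3, 1)
t_shape, t_stride = (3, 2), (1, 3)        # x.T: strides swapped, no copy
print(is_contiguous((2, 3), (3, 1)))      # True
print(is_contiguous(t_shape, t_stride))   # False
new_storage, new_shape, new_stride = contiguous_2d(storage, t_shape, t_stride)
print(new_storage, new_stride)            # [0, 3, 1, 4, 2, 5] (2, 1)
```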

Quick recap: How do we make operators run faster?

  • Vectorization
    • Platform-specific vectorized functions
    • Reduce seek time
  • Data layout
    • stride format
    • zero copy
    • Fast tensor manipulation through stride
  • Parallelization on CPUs

2 Matrix Multiplication

Matrix multiplication is crucial because ML workloads perform so many of them. How do we implement a vanilla matrix multiplication?

// Calculate C = A @ B.T
float A[n][n], B[n][n], C[n][n];

for (int i = 0; i < n; i++)
{
    for (int j = 0; j < n; j++)
    {
        C[i][j] = 0;
        for (int k = 0; k < n; k++)
        {
            C[i][j] += A[i][k] * B[j][k];
        }
    }
}

We use three nested loops: the first iterates over the rows of A, the second over the columns of B.T (i.e., the rows of B), and the third reduces (sums) the elementwise products of a row of A and a row of B.
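The same triple loop in plain Python, as a reference implementation of C = A @ B.T on nested lists (matmul_abt is a hypothetical name). It is meant for checking results, not for speed.

```python
def matmul_abt(A, B):
    """Naive C = A @ B.T for square n x n lists: C[i][j] = sum_k A[i][k] * B[j][k]."""
    n = len(A)
    C = [[0.0] * n for _ in range(n)]
    for i in range(n):             # rows of A
        for j in range(n):         # rows of B == columns of B.T
            acc = 0.0
            for k in range(n):     # reduction over the shared dimension
                acc += A[i][k] * B[j][k]
            C[i][j] = acc
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(matmul_abt(A, B))   # [[17.0, 23.0], [39.0, 53.0]]
```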

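Obviously, the complexity is $O(n^3)$. To make this code faster, we again need to maximize its arithmetic intensity (AI), i.e., the number of arithmetic operations per element moved through the memory hierarchy.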

When we execute code, we have to move variables between different levels of the memory hierarchy. Memory in a computer forms a pyramid: higher levels are faster but also much more expensive, and therefore smaller.

  • Registers
  • L1 cache
  • L2 cache
  • DRAM

Let's consider a simplified memory hierarchy with only registers and DRAM, and annotate the code above accordingly.

// Calculate C = A @ B.T
dram float A[n][n], B[n][n], C[n][n];

for (int i = 0; i < n; i++)
{
    for (int j = 0; j < n; j++)
    {
        register float c = 0;
        for (int k = 0; k < n; k++)
        {
            register float a = A[i][k];
            register float b = B[j][k];
            c += a * b;
        }
        C[i][j] = c;
    }
}

The annotated code clearly shows the data flow from DRAM to registers. For two $n\times n$ matrices, we read $n^3$ times for a, $n^3$ times for b, and write the results into C $n^2$ times. We need 3 registers for a, b and c, and the read cost is

$$2n^3\cdot \mathrm{speed}(\text{DRAM}\rightarrow \text{register})$$
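To check these counts, this sketch instruments the naive loop, treating each of the local variables a, b and c as one register and each list access as a DRAM load. That is a modelling assumption: Python gives no control over real registers, so only the counts are meaningful.

```python
def naive_with_counts(A, B):
    """Naive C = A @ B.T that counts DRAM -> register loads and register -> DRAM stores."""
    n = len(A)
    C = [[0.0] * n for _ in range(n)]
    loads_a = loads_b = stores_c = 0
    for i in range(n):
        for j in range(n):
            c = 0.0                 # "register" accumulator
            for k in range(n):
                a = A[i][k]         # one load of a
                b = B[j][k]         # one load of b
                loads_a += 1
                loads_b += 1
                c += a * b
            C[i][j] = c             # one store of c
            stores_c += 1
    return C, loads_a, loads_b, stores_c


n = 8
A = [[float(i + k) for k in range(n)] for i in range(n)]
B = [[float(i * k % 3) for k in range(n)] for i in range(n)]
_, la, lb, sc = naive_with_counts(A, B)
print(la, lb, sc)   # 512 512 64, i.e. n^3, n^3, n^2
```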

One way to optimize matrix multiplication is tiling. The code below shows register-tiled matrix multiplication:

// Calculate C = A @ B.T
dram float A[n/v1][n/v3][v1][v3];
dram float B[n/v2][n/v3][v2][v3];
dram float C[n/v1][n/v2][v1][v2];

for (int i = 0; i < n/v1; i++)
{
    for (int j = 0; j < n/v2; j++)
    {
        register float c[v1][v2] = 0;
        for (int k = 0; k < n/v3; k++)
        {
            register float a[v1][v3] = A[i][k];
            register float b[v2][v3] = B[j][k];
            c += dot(a, b.T);
        }
        C[i][j] = c;
    }
}

First, notice that matrices A, B and C are split into tiles, somewhat like partitioning an image into patches. For instance, the rows of A are split into n/v1 groups and the columns of A into n/v3 groups.

The first loop now iterates over the row groups of A and the second over the row groups of B. The third loop computes a small matrix multiplication for each pair of tiles while scanning over the column groups. So there is no special trick: it is just block matrix multiplication.
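A pure-Python sketch of the register-tiled loop, assuming n is divisible by v1, v2 and v3. Python lists stand in for the register tiles, so only the load counts are meaningful, not the timing; tiled_matmul_abt is a hypothetical name. The counters should match the counts derived next in Equations (1) and (2).

```python
def tiled_matmul_abt(A, B, v1, v2, v3):
    """Register-tiled C = A @ B.T; counts elements loaded into the a and b tiles."""
    n = len(A)
    C = [[0.0] * n for _ in range(n)]
    loads_a = loads_b = 0
    for i in range(n // v1):
        for j in range(n // v2):
            c = [[0.0] * v2 for _ in range(v1)]          # v1 x v2 accumulator tile
            for k in range(n // v3):
                a = [A[i * v1 + p][k * v3:(k + 1) * v3] for p in range(v1)]  # v1 x v3
                b = [B[j * v2 + q][k * v3:(k + 1) * v3] for q in range(v2)]  # v2 x v3
                loads_a += v1 * v3
                loads_b += v2 * v3
                for p in range(v1):                       # c += a @ b.T
                    for q in range(v2):
                        c[p][q] += sum(x * y for x, y in zip(a[p], b[q]))
            for p in range(v1):                           # write the tile back
                for q in range(v2):
                    C[i * v1 + p][j * v2 + q] = c[p][q]
    return C, loads_a, loads_b


n = 8
A = [[float((3 * i + k) % 7) for k in range(n)] for i in range(n)]
B = [[float((j + 2 * k) % 5) for k in range(n)] for j in range(n)]
C, la, lb = tiled_matmul_abt(A, B, v1=2, v2=4, v3=2)
ref = [[sum(A[i][k] * B[j][k] for k in range(n)) for j in range(n)] for i in range(n)]
assert C == ref
print(la, lb)   # 128 256, i.e. n^3/v2 and n^3/v1
```

The small integer-valued inputs keep every floating-point sum exact, so the tiled result can be compared to the reference with ==.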

Now let's count the memory accesses again. We read

$$\frac{n}{v1}\times \frac{n}{v2}\times \frac{n}{v3}\times v1\times v3=\frac{n^3}{v2}\tag{1}$$

elements for a. The factor $v1\times v3$ appears because each load brings a whole $v1\times v3$ tile of A into registers.

We read

$$\frac{n}{v1}\times \frac{n}{v2}\times \frac{n}{v3}\times v2\times v3=\frac{n^3}{v1}\tag{2}$$

elements for b, and we write

$$\frac{n}{v1}\times \frac{n}{v2}\times v1\times v2=n^2$$

elements for c.

This time, however, we need $v1\cdot v2+v1\cdot v3+v2\cdot v3$ registers (for the c, a and b tiles). The read cost is reduced to

$$\left(\frac{n^3}{v1}+\frac{n^3}{v2}\right)\cdot \mathrm{speed}(\text{DRAM}\rightarrow \text{register})$$
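Plugging numbers into this cost model makes the trade-off concrete. The sketch assumes n = 1024 and a few tile shapes chosen purely for illustration, and counts registers as scalar slots.

```python
def naive_reads(n):
    """Loads of the untiled loop: n^3 for a plus n^3 for b."""
    return 2 * n ** 3


def tiled_reads(n, v1, v2):
    """Equation (1) + Equation (2): n^3/v2 loads for a plus n^3/v1 loads for b."""
    return n ** 3 // v2 + n ** 3 // v1


def tiled_registers(v1, v2, v3):
    """Registers for the c (v1 x v2), a (v1 x v3) and b (v2 x v3) tiles."""
    return v1 * v2 + v1 * v3 + v2 * v3


n = 1024
rows = []
for v1, v2, v3 in [(1, 1, 1), (4, 4, 1), (8, 8, 1), (8, 8, 4)]:
    reduction = naive_reads(n) / tiled_reads(n, v1, v2)
    regs = tiled_registers(v1, v2, v3)
    rows.append((v1, v2, v3, reduction, regs))
    print(f"v1={v1} v2={v2} v3={v3}: {reduction:.0f}x fewer reads, {regs} registers")
```

With all tile sizes equal to 1 the model falls back to the naive loop (2n^3 reads, 3 registers). The last configuration shows facts 1 and 3 together: raising v3 from 1 to 4 buys no reduction in reads but costs 48 extra registers.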

Here we can observe some important facts:

  1. The read cost is independent of $v3$.
  2. We must choose $v1$ and $v2$ such that all tiles fit in the registers.
  3. We are trading more space (registers) for speed.

But intuitively, how did we reduce the read cost? Looking closely at Equation (1), the reads of a are divided by $v2$: if we did not tile B, the second loop would become for (int j = 0; j < n; j++) and this factor would disappear.

So the reads of A are reduced not because we tiled A, but because we tiled B: each A tile loaded into registers is reused against $v2$ rows of B at once.

The same holds for B. Equation (2) shows that tiling B itself contributes nothing to its read count; the factor $1/v1$ comes from the first loop, i.e., each B tile is reused against $v1$ rows of A.

Now let's move to a more realistic setting with registers, L1 cache and DRAM.

// Calculate C = A @ B.T
dram float A[n/b1][b1][n];
dram float B[n/b2][b2][n];
dram float C[n/b1][n/b2][b1][b2];

for (int i = 0; i < n/b1; i++)
{
    l1cache float a[b1][n] = A[i];
    for (int j = 0; j < n/b2; j++)
    {
        l1cache float b[b2][n] = B[j];
        // Register calculations with tiling
        C[i][j] = dot(a, b.T);
    }
}

Note that we don't tile the columns here, because Equations (1) and (2) showed that the tile size along the columns ($v3$) does not change the memory access cost.

The code above is called cache-aware tiling, while the previous version is called register-aware tiling. The data movement now becomes

  1. DRAM
  2. DRAM -> l1cache
  3. l1cache -> register

Let's calculate the memory access cost:

  • For A: $(n/b1)\cdot n\cdot b1=n^2$
  • For B: $(n/b1)\cdot (n/b2)\cdot b2\cdot n=n^3/b1$
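A pure-Python simulation of cache-aware tiling under the same counting model, assuming n is divisible by b1 and b2; copying a panel of rows stands in for a DRAM -> L1 transfer. cache_tiled_matmul_abt is a hypothetical name.

```python
def cache_tiled_matmul_abt(A, B, b1, b2):
    """Cache-aware C = A @ B.T; counts elements copied DRAM -> "L1"."""
    n = len(A)
    C = [[0.0] * n for _ in range(n)]
    dram_to_l1_a = dram_to_l1_b = 0
    for i in range(n // b1):
        a = [row[:] for row in A[i * b1:(i + 1) * b1]]       # b1 x n panel of A
        dram_to_l1_a += b1 * n
        for j in range(n // b2):
            b = [row[:] for row in B[j * b2:(j + 1) * b2]]   # b2 x n panel of B
            dram_to_l1_b += b2 * n
            for p in range(b1):                               # C tile = a @ b.T
                for q in range(b2):
                    C[i * b1 + p][j * b2 + q] = sum(x * y for x, y in zip(a[p], b[q]))
    return C, dram_to_l1_a, dram_to_l1_b


n = 8
A = [[float((3 * i + k) % 7) for k in range(n)] for i in range(n)]
B = [[float((j + 2 * k) % 5) for k in range(n)] for j in range(n)]
C, ta, tb = cache_tiled_matmul_abt(A, B, b1=2, b2=4)
ref = [[sum(A[i][k] * B[j][k] for k in range(n)) for j in range(n)] for i in range(n)]
assert C == ref
print(ta, tb)   # 64 256, i.e. n^2 and n^3/b1
```

The A panel is loaded once per outer iteration, while the B panel is reloaded for every (i, j) pair, which is where the n^3/b1 term comes from.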

Putting cache-aware tiling and register-aware tiling together, we get

// tile for both l1cache and register
dram float A[n/b1][b1/v1][n][v1];
dram float B[n/b2][b2/v2][n][v2];
dram float C[n/b1][n/b2][b1/v1][b2/v2][v1][v2];

// l1 cache
for (int i = 0; i < n/b1; i++)
{
    l1cache float a[b1/v1][n][v1] = A[i];
    for (int j = 0; j < n/b2; j++)
    {
        l1cache float b[b2/v2][n][v2] = B[j];
        // register
        for (int x = 0; x < b1/v1; x++)
        {
            for (int y = 0; y < b2/v2; y++)
            {
                register float c[v1][v2] = 0;
                for (int k = 0; k < n; k++)
                {
                    register float ar[v1] = a[x][k];
                    register float br[v2] = b[y][k];
                    c += dot(ar, br.T);  // outer product: c[p][q] += ar[p] * br[q]
                }
                C[i][j][x][y] = c;
            }
        }
    }
}

The cost of the overall loop is

$$\begin{aligned} C &= \color{blue}{\text{DRAM}} + \color{green}{\text{Register}}\\ &=\color{blue}{(n/b1)\cdot(b1/v1)\cdot n\cdot v1+(n/b1)\cdot(n/b2)\cdot(b2/v2)\cdot n\cdot v2}\\ &+\color{green}{(n/b1)\cdot(n/b2)\cdot(b1/v1)\cdot(b2/v2)\cdot n\cdot v1}\\ &+\color{green}{(n/b1)\cdot(n/b2)\cdot(b1/v1)\cdot(b2/v2)\cdot n\cdot v2}\\ &=\color{blue}{n^2+n^3/b1} + \color{green}{n^3/v2+n^3/v1} \end{aligned}$$
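Finally, a sketch that combines both levels, counting DRAM -> L1 traffic and L1 -> register loads separately so that each term of the cost formula can be checked. It assumes n is divisible by b1 and b2, and b1, b2 by v1, v2; the function name and the chosen sizes are illustrative.

```python
def two_level_matmul_abt(A, B, b1, b2, v1, v2):
    """C = A @ B.T with L1 tiles (b1, b2) and register tiles (v1, v2), with traffic counters."""
    n = len(A)
    C = [[0.0] * n for _ in range(n)]
    dram_a = dram_b = reg_a = reg_b = 0
    for i in range(n // b1):
        a = A[i * b1:(i + 1) * b1]            # "L1" panel: b1 rows of A (slice, counted as a copy)
        dram_a += b1 * n
        for j in range(n // b2):
            b = B[j * b2:(j + 1) * b2]        # "L1" panel: b2 rows of B
            dram_b += b2 * n
            for x in range(b1 // v1):
                for y in range(b2 // v2):
                    c = [[0.0] * v2 for _ in range(v1)]             # register tile
                    for k in range(n):
                        ar = [a[x * v1 + p][k] for p in range(v1)]  # v1 values
                        br = [b[y * v2 + q][k] for q in range(v2)]  # v2 values
                        reg_a += v1
                        reg_b += v2
                        for p in range(v1):                         # outer-product update
                            for q in range(v2):
                                c[p][q] += ar[p] * br[q]
                    for p in range(v1):
                        for q in range(v2):
                            C[i * b1 + x * v1 + p][j * b2 + y * v2 + q] = c[p][q]
    return C, dram_a, dram_b, reg_a, reg_b


n = 8
A = [[float((3 * i + k) % 7) for k in range(n)] for i in range(n)]
B = [[float((j + 2 * k) % 5) for k in range(n)] for j in range(n)]
C, da, db, ra, rb = two_level_matmul_abt(A, B, b1=4, b2=8, v1=2, v2=4)
ref = [[sum(A[i][k] * B[j][k] for k in range(n)) for j in range(n)] for i in range(n)]
assert C == ref
print(da, db, ra, rb)   # 64 128 128 256: n^2, n^3/b1, n^3/v2, n^3/v1
```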

This multi-level tiling is what happens on real CPUs, where data may flow from Disk -> DRAM -> L2 -> L1 -> Register, so our variables must move across at least three levels.

We can also overlap the reads: e.g., while data is being read from DRAM into the L2 cache, we can simultaneously move other data from L2 to L1 and from L1 to registers.

3 GPUs

Parallelization is extremely useful, and today people do massive parallelization on GPUs. A simple underlying concept is Single-Instruction Multiple-Data (SIMD): a single instruction is applied to multiple data elements simultaneously.

One example is the vectorized tensor addition we explored before, where the same addition is applied to many elements at once; we can additionally assign a share of the additions to each CPU core and let the cores run simultaneously.

A CPU must always devote chip area to control logic and caches, so the ALUs occupy only part of it, as sketched below.

[Figure: CPU chip area, divided among a control unit, four ALUs and caches]

As manufacturing techniques improve, ALUs get smaller and smaller, so we can fit more of them into a fixed area.

However, physical constraints prevent us from shrinking transistors indefinitely, so Moore's Law is coming to an end. One way to accelerate further is to develop specialized hardware.

[Figure: CPU vs. GPU. CPU: a few cores, each with its own control unit and L1 cache, backed by an L2 cache, an L3 cache and DRAM. GPU: many small cores (...) sharing an L2 cache and DRAM.]

The diagram above compares a CPU with a GPU. A GPU has many weak, specialized cores that can do only a few things (e.g., some cores can only perform matrix multiplication). Because they are so simple, we can strip out general-purpose logic and make each core smaller, so a single chip can hold a huge number of cores.

Some of the most popular specialized chips are Nvidia GPUs, which are now ubiquitous in DL. Google has the Tensor Processing Unit (TPU), an Application-Specific Integrated Circuit (ASIC) created in the mid-2010s; it was used to train AlphaGo.

Today the community is building functionality-specialized chips (e.g., chips that can only compute matmul), and is also reducing precision, e.g., FP32 -> FP16 -> FP8 -> INT8. Designers also tune how chip resources are allocated among components for specific workloads, such as the sizes of SRAM, caches and register files.

This is why quantization is so important: reduced precision yields much higher compute throughput.

The rest of this topic is best studied by reading product datasheets from Nvidia and from accelerator startups.


https://jesseprince.github.io/2025/06/28/ai_infra/mlsys/2_ops_matmul/
Author: 林正
Posted on: June 28, 2025
Licensed under