02-Behind ML Framework

Last updated: June 26, 2025 pm

These are my lecture notes for the course CSE 234 - Data Systems for ML - LE [A00].

1 Auto Differentiation (AD)

Given a computation, we first construct a graph for it. Take

$$
y = f(x_1, x_2) = \ln(x_1) + x_1 x_2 - \sin x_2
$$

as an example; the graph may look like this:

graph LR
x_1 --> v_1
x_2 --> v_2
v_1 --> v4["(mul)v_4"]
v_2 --> v4["(mul)v_4"]
v_1 --> |ln| v_3
v_3 --> v_6
v4 --> v_6
v_2 --> |sin| v_5
v_6 --> |"(+)"| v_7
v_5 --> |"(-)"| v_7
v_7 --> y

When training a NN, we always want the derivative of the output with respect to the values on the left side of the graph, e.g. $\frac{\partial y}{\partial x_1}$. There are two ways to get the derivative: forward mode and backward mode.

Forward Mode AD. In forward mode AD, we define

$$
\dot v_i = \frac{\partial v_i}{\partial x_j}, \quad j\in \{1,2\}
$$

We then calculate each $\dot v_i$ for the variable $x_1$ in the graph:

$$
\begin{aligned}
&\dot v_1 = 1\\
&\dot v_2 = 0\\
&\dot v_3 = (\ln (v_1))' = \dot v_1 / v_1\\
&\dot v_4 = (v_1 v_2)' = \dot v_1 v_2 + \dot v_2 v_1 = \dot v_1 v_2 \\
&\dot v_5 = (\sin v_2)' = \dot v_2 \cos v_2 = 0 \\
&\dot v_6 = (v_3 + v_4)' = \dot v_3 + \dot v_4 \\
&\dot v_7 = (v_6 - v_5)' = \dot v_6 - \dot v_5
\end{aligned}
$$

So we start from the very left of the graph and compute each node's derivative with respect to the chosen input; we finally reach the right side of the graph and get $\frac{\partial y}{\partial x_1}$. However, you may have noticed that we need to do the same thing again for $x_2$, and in modern NNs the input side is extremely high dimensional, which means we must run this procedure many times to get all partial derivatives.
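
To make this concrete, here is a minimal forward-mode sketch of my own (not part of the lecture) using dual numbers: each value carries a pair (v, dv), the primal value and its derivative with respect to the chosen input.

import math

# Minimal forward-mode AD sketch: every value carries (v, dv), where dv is the
# derivative w.r.t. the seeded input variable.
class Dual:
    def __init__(self, v, dv=0.0):
        self.v, self.dv = v, dv
    def __add__(self, other):
        return Dual(self.v + other.v, self.dv + other.dv)
    def __sub__(self, other):
        return Dual(self.v - other.v, self.dv - other.dv)
    def __mul__(self, other):
        return Dual(self.v * other.v, self.dv * other.v + self.v * other.dv)

def ln(x):
    return Dual(math.log(x.v), x.dv / x.v)

def sin(x):
    return Dual(math.sin(x.v), x.dv * math.cos(x.v))

# dy/dx1 at (x1, x2) = (2, 5): seed x1 with derivative 1 and x2 with 0
x1, x2 = Dual(2.0, 1.0), Dual(5.0, 0.0)
y = ln(x1) + x1 * x2 - sin(x2)
print(y.dv)  # 1/x1 + x2 = 5.5

To get $\frac{\partial y}{\partial x_2}$ we would have to rerun the whole computation with the seeds swapped, which is exactly the inefficiency described above.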

Backward Mode AD. In backward mode AD, we define

$$
\bar v_i = \frac{\partial y}{\partial v_i}
$$

We then calculate each $\bar v_i$ in reverse order of the graph, starting from the very right side of the graph:

$$
\begin{aligned}
&\bar v_7 = \frac{\partial y}{\partial v_7} = 1\\
&\bar v_6 = \frac{\partial y}{\partial v_6} = \frac{\partial y}{\partial v_7}\frac{\partial v_7}{\partial v_6} = \bar v_7 \frac{\partial v_7}{\partial v_6} = 1\\
&\bar v_5 = \bar v_7\frac{\partial v_7}{\partial v_5} = -1\\
&\bar v_4 = \bar v_6\frac{\partial v_6}{\partial v_4} = 1\\
&\bar v_3 = \bar v_6\frac{\partial v_6}{\partial v_3} = 1\\
&\bar v_2 = \bar v_5\frac{\partial v_5}{\partial v_2} + \bar v_4\frac{\partial v_4}{\partial v_2} = \bar v_5\cos v_2 + \bar v_4 \frac{\partial (v_1 v_2)}{\partial v_2} = \bar v_5\cos v_2 + v_1\bar v_4\\
&\bar v_1 = \bar v_4\frac{\partial v_4}{\partial v_1} + \bar v_3\frac{\partial v_3}{\partial v_1} = \bar v_4 v_2 + \bar v_3\frac{1}{v_1}
\end{aligned}
$$

Here, you may notice that for a node feeding more than one branch, its derivative is the sum of the contributions from its consumer nodes. Consider the graph below:

graph LR
v_1 --> v_2
v_1 --> v_3
v_2 --> v_4
v_3 --> v_4
v_4 --> y

The derivative is

$$
\bar v_1 = \frac{\partial y}{\partial v_1} = \frac{\partial f(v_2,v_3)}{\partial v_1} = \frac{\partial f(v_2,v_3)}{\partial v_2}\frac{\partial v_2}{\partial v_1} + \frac{\partial f(v_2,v_3)}{\partial v_3}\frac{\partial v_3}{\partial v_1} = \bar v_2\frac{\partial v_2}{\partial v_1} + \bar v_3 \frac{\partial v_3}{\partial v_1}
$$

So for a $v_i$ used by multiple consumers, we have

$$
\bar v_i = \sum_{j\in \text{next nodes}}\bar v_{i\rightarrow j}, \quad \bar v_{i\rightarrow j} = \bar v_j \frac{\partial v_j}{\partial v_i}\tag{1}
$$
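
As a quick numeric illustration of Equation (1) (my own toy example, checked with PyTorch, not from the lecture): when $v_1$ feeds two consumers, its gradient is the sum of the two branch contributions.

import torch

v1 = torch.tensor(2.0, requires_grad=True)
v2, v3 = v1 ** 2, 3 * v1           # two consumers of v1
y = v2 * v3                        # y = f(v2, v3)
y.backward()

# dy/dv1 = v3 * dv2/dv1 + v2 * dv3/dv1 = 3*v1 * 2*v1 + v1**2 * 3 = 9 * v1**2
print(v1.grad)                     # tensor(36.)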

So we start from the output node and derive gradients all the way back to the input nodes. Notice that we calculated the gradients for both $x_1$ and $x_2$ in a single pass; this is because there is only one output node.

In ML problems, we always have one output node, which is the scalar loss, and extremely high dimensional inputs. Hence, using backward AD is much more efficient.

2 Constructing Backward Graphs

The pseudocode for backward graph computation is shown as follows

def compute_gradient(output_node):
    # Initialize gradients: the gradient of the output w.r.t. itself is 1
    node_to_grad = {output_node: [1.0]}

    # Traverse nodes in reverse topological order
    for node in reverse_topological_order(output_node):
        # Accumulate total gradient flowing into this node
        total_grad = sum(node_to_grad[node])  # Equ. (1)

        # Propagate gradients to inputs
        for input_node in inputs(node):
            # Compute local gradient contribution via chain rule
            local_grad = total_grad * derivative(node, input_node)  # Equ. (1)

            # Accumulate gradient for input node
            if input_node not in node_to_grad:
                node_to_grad[input_node] = []
            node_to_grad[input_node].append(local_grad)

    # Return the accumulated gradient lists; summing the list of a node
    # gives its gradient, including the gradients of the original input(s)
    return node_to_grad

The logic behind it is straightforward: we start from the output node and, for each node $i$ in reverse topological order, accumulate its total gradient according to Equation (1) in case it has multiple consumers. Then we calculate $\bar v_{k\rightarrow i} = \bar v_i \frac{\partial v_i}{\partial v_k}$ for every input node $k$ of $i$; we can do this because $\bar v_i$ is already available from the accumulation, so the gradients of the previous nodes can be computed.

The dictionary node_to_grad is initialized as {output_node: [1.0]} because the output node connects to $y$ through an identity function, so the gradient of the output with respect to itself is 1.

Let’s run through an example

graph LR
v_1 --> |exp| v_2
1 --> v_3
v_2 --> v_3
v_2 --> v4["(mul)v_4"]
v_3 --> v4["(mul)v_4"]

The math equation is $v_4 = \exp(v_1)(1 + \exp(v_1))$, and the gradient with respect to $v_1$ is $\exp(v_1)(1 + \exp(v_1)) + \exp(2v_1)$.

# iter 0
node = v_4
total_grad = 1.0
# sub_iter 0
input_node = v_3
local_grad = 1.0 * derivative(v_4, v_3) = v_2
node_to_grad = {v_4: [1.0], v_3: [v_2]}
# sub_iter 1
input_node = v_2
local_grad = 1.0 * derivative(v_4, v_2) = v_3
node_to_grad = {v_4: [1.0], v_3: [v_2], v_2: [v_3]}

Next we go one step back to node $v_3$

# iter 1
node = v_3
total_grad = v_2
# sub_iter 0
input_node = v_2
local_grad = v_2 * derivative(v_3, v_2) = v_2
node_to_grad = {v_4: [1.0], v_3: [v_2], v_2: [v_3, v_2]}
# sub_iter 1
input_node = constant
local_grad = v_2 * derivative(v_3, constant) = v_2
node_to_grad = {v_4: [1.0], v_3: [v_2], v_2: [v_3, v_2], constant: [v_2]}

Next we go back to node $v_2$

# iter 2
node = v_2
total_grad = v_3 + v_2
# sub_iter 0
input_node = v_1
local_grad = (v_3+v_2) * derivative(v_2, v_1) = (v_3+v_2)exp(v_1)
node_to_grad = {
v_4: [1.0],
v_3: [v_2],
v_2: [v_3, v_2],
v_1: [(v_3+v_2)exp(v_1)],
constant: [1]
}

Finally, we take the sum of the list in the dictionary; the gradient for $v_1$ is

$$
\begin{aligned}
\bar v_1 &= (v_3 + v_2)\exp(v_1) \\
&= (1 + \exp(v_1) + \exp(v_1))\exp(v_1) \\
&= \exp(v_1)(1 + \exp(v_1)) + \exp(2v_1)
\end{aligned}
$$

which is exactly the same as the gradient we calculated analytically.
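
As a sanity check (mine, not from the notes), PyTorch autograd gives the same result at an arbitrary point:

import torch

# Compare PyTorch autograd with the hand-derived gradient
# exp(v1)(1 + exp(v1)) + exp(2*v1), evaluated at v1 = 0.7
v1 = torch.tensor(0.7, requires_grad=True)
v4 = torch.exp(v1) * (1 + torch.exp(v1))
v4.backward()

analytic = (torch.exp(v1) * (1 + torch.exp(v1)) + torch.exp(2 * v1)).detach()
print(torch.allclose(v1.grad, analytic))  # True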

Modern ML frameworks like PyTorch also construct a backward graph that is connected to the forward graph (since some forward results are needed to calculate the gradients).
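
You can peek at this backward graph through each tensor's grad_fn attribute (a quick illustration of mine, reusing the example above):

import torch

v1 = torch.tensor(0.7, requires_grad=True)
v4 = torch.exp(v1) * (1 + torch.exp(v1))

# Each intermediate result points to the backward node that will compute its
# gradient; next_functions links these nodes together into the backward graph.
print(v4.grad_fn)                 # a MulBackward0 node
print(v4.grad_fn.next_functions)  # references to the ExpBackward0 / AddBackward0 nodes it consumes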

Homework: How to derive the gradient of softmax cross entropy

We know that for $x \in \mathbb{R}^d$, we have

$$
y_i = \mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^d e^{x_j}}
$$

then cross entropy is calculated by

$$
L = -\sum_i \text{gt}_i\log(y_i)
$$

To calculate the gradient with respect to $x_k$:

$$
\frac{\partial L}{\partial x_k} = -\sum_i\text{gt}_i\frac{\partial }{\partial x_k}\log(y_i)
$$

when $i \neq k$,

$$
\begin{aligned}
\frac{\partial}{\partial x_k}\log(y_i) &= \frac{\partial}{\partial x_k}\Big(x_i - \log\big(\sum_{j=1}^d e^{x_j}\big)\Big) \\
&= -\frac{e^{x_k}}{\sum_{j=1}^d e^{x_j}} \\
&= -y_k
\end{aligned}
$$

when $i = k$,

$$
\begin{aligned}
\frac{\partial}{\partial x_i}\log(y_i) &= \frac{\partial}{\partial x_i}\Big(x_i - \log\big(\sum_{j=1}^d e^{x_j}\big)\Big) \\
&= 1 - \frac{e^{x_i}}{\sum_{j=1}^d e^{x_j}} \\
&= 1 - y_k
\end{aligned}
$$

So the gradient of the softmax cross entropy is

$$
\begin{aligned}
\frac{\partial L}{\partial x_k} &= -\sum_i\text{gt}_i\frac{\partial }{\partial x_k}\log(y_i) \\
&= -\Big(\text{gt}_k(1 - y_k) - y_k\sum_{i\neq k}\text{gt}_i\Big)
\end{aligned}
$$

Note that for a probability distribution,

$$
\sum_i \text{gt}_i = 1 \Rightarrow \sum_{i\neq k}\text{gt}_i = 1 - \text{gt}_k
$$

then

$$
\begin{aligned}
y_k\sum_{i\neq k}\text{gt}_i - \text{gt}_k(1 - y_k) &= y_k(1 - \text{gt}_k) - \text{gt}_k(1 - y_k) \\
&= y_k - \text{gt}_k
\end{aligned}
$$

This indicates that the gradient of the softmax cross entropy loss is simply the difference between the predicted probability and the ground truth, $y_k - \text{gt}_k$.
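
A quick numeric check of this result (my own, using PyTorch, with gt taken as a one-hot ground-truth vector):

import torch
import torch.nn.functional as F

x = torch.randn(5, requires_grad=True)
gt = F.one_hot(torch.tensor(2), num_classes=5).float()  # one-hot ground truth

# Softmax cross entropy written out explicitly, as in the derivation above
loss = -(gt * torch.log_softmax(x, dim=0)).sum()
loss.backward()

# The gradient should equal softmax(x) - gt
print(torch.allclose(x.grad, torch.softmax(x.detach(), dim=0) - gt))  # True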

3 ML Framework Hierarchy

The hierarchy of a modern ML framework looks like this:

graph BT

d1["Operator optimization/compilation"]
d2["Runtime: schedule/memory"]
d3["Parallelization"]
d4["Graph optimization"]
d5["Autodiff"]
d6["Dataflow graph"]

d1 --> d2
d2 --> d3
d3 --> d4
d4 --> d5
d5 --> d6

3.1 Graph Optimization

The goal of graph optimization is to make the graph run faster. Formally, given $G$, we want to rewrite $G$ into $G'$ such that $G'$ runs faster than $G$.

One case is a Conv2D layer followed by a BatchNorm layer: the two operators can be fused into a single layer. After fusion, the graph is smaller and runs faster.
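
A rough sketch of my own of how this fusion works at inference time (simplified, ignoring details like groups and dilation; fuse_conv_bn is a hypothetical helper, not a framework API): the BatchNorm scale and shift are folded into the convolution's weight and bias.

import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta is itself a
    # convolution with rescaled weights and an adjusted bias
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, bias=True)
    with torch.no_grad():
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        conv_bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused

conv, bn = nn.Conv2d(3, 8, 3, bias=False).eval(), nn.BatchNorm2d(8).eval()
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-6))  # True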

Another case study is attention. Instead of

Q = matmul(W_q, h)
K = matmul(W_k, h)
V = matmul(W_v, h)

We compute qkv in a merged fashion

qkv = matmul(concat(W_q, W_k, W_v), h)
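
For concreteness, here is a small check of my own (with arbitrary shapes) that the merged matmul followed by a split gives the same result as three separate matmuls:

import torch

d, n = 64, 10
W_q, W_k, W_v = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)
h = torch.randn(d, n)

qkv = torch.cat([W_q, W_k, W_v], dim=0) @ h   # one (3d x d) @ (d x n) matmul
Q, K, V = qkv.split(d, dim=0)                 # slice qkv back into three tensors

print(torch.allclose(Q, W_q @ h), torch.allclose(K, W_k @ h), torch.allclose(V, W_v @ h))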

As shown above, the variable qkv can then be split back into three variables. You may ask why this is faster; the answer lies in the metric Arithmetic Intensity (AI), defined by

$$
\text{AI} = \frac{\#\text{ops}}{\#\text{memory access bytes}}\tag{2}
$$

Higher AI means higher efficiency, because more time is spent on computation instead of memory access. Let's compare two implementations.

void add(int n, float* A, float* B, float* C)
{
    for (int i = 0; i < n; i++)
    {
        C[i] = A[i] + B[i];
    }
}

void mul(int n, float* A, float* B, float* C)
{
    for (int i = 0; i < n; i++)
    {
        C[i] = A[i] * B[i];
    }
}

float *A, *B, *C, *D, *E, *tmp1, *tmp2;
// calculate E = D + ((A + B) * C)
add(n, A, B, tmp1);
mul(n, tmp1, C, tmp2);
add(n, tmp2, D, E);

Inside the functions add and mul, each loop iteration requires three memory accesses and one arithmetic operation (either an add or a multiply), so $\text{AI} = 1/3$. The three function calls therefore have an average AI of $1/3$.

Now let’s fuse them together

void fused(int n, float* A, float* B, float* C, float* D, float* E)
{
    for (int i = 0; i < n; i++)
    {
        E[i] = D[i] + (A[i] + B[i]) * C[i];
    }
}

Now each iteration performs 5 memory accesses and 3 arithmetic operations, so $\text{AI} = 3/5$, which is higher than $1/3$. How do we perform graph optimization? One way is to write rules and templates: scan the whole graph, and whenever you find operators that can be fused, fuse them. Another technique is auto discovery, which automatically discovers operators that can be fused.

3.2 Parallelization

Diving one layer deeper, we arrive at the parallelization layer, where we want to parallelize the computation graph. In this layer, people care about questions such as:

  • How to partition
  • How to communicate
  • How to schedule
  • Consistency (The final result doesn’t change after parallelization)
  • How to auto-parallelize (Algorithm builders are agnostic to the system)

3.3 Runtime and Scheduling

One layer down, the overall goal is to schedule memory, compute, and communication so that the graph runs as fast as possible, overlapping communication with compute when possible, while satisfying memory constraints.

3.4 Operator optimization

We want the fastest implementations of commonly used operators, like matmul. We may optimize them for different hardware, precisions, and shapes.

In this layer, the goal is to maximize the AI, which is defined in Equation (2).

4 Operator Optimization: A High Level View

4.1 Vectorization

#include <xmmintrin.h>  // SSE intrinsics for _mm_load_ps / _mm_add_ps / _mm_store_ps

float A[256], B[256], C[256];

// unvectorized
for (int i = 0; i < 256; ++i)
{
    C[i] = A[i] + B[i];
}

// vectorized
for (int i = 0; i < 256; i += 4) {
    __m128 a = _mm_load_ps(&A[i]); // Load 4 floats
    __m128 b = _mm_load_ps(&B[i]); // Load 4 floats
    __m128 c = _mm_add_ps(a, b);   // SIMD add
    _mm_store_ps(&C[i], c);        // Store result
}

The vectorized version loads 4 floats at a time, adds them with a single SIMD vector add, then stores the results into array C.

4.2 Data Layout

Different layouts of the data in memory can also make a difference in efficiency.

Normally, a matrix is stored sequentially in memory; we access it either in row major order (A[i,j] = A.data[i*A.shape[1]+j]) or column major order (A[i,j] = A.data[j*A.shape[0]+i]).

However, ML frameworks store data in a strided format. In PyTorch, Tensor.stride(i) tells you how many elements need to be skipped in memory to move by one index in the i-th dimension of the tensor, e.g.

>>> a = torch.randn(3,5)
>>> a
tensor([[-1.5500, -1.6692, -1.1693, -1.0516, 0.1409],
[ 0.9175, 1.4273, -1.5217, 0.6138, -0.5706],
[-0.9240, 1.6548, -0.8674, -0.8225, 0.4161]])
>>> a.stride(0)
5
>>> a.stride(1)
1

In the above example, moving by one index along the 0th dimension skips 5 elements in memory, and 1 element is skipped when moving along the 1st dimension.
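
For instance (a quick check of my own, valid because a is contiguous), the element a[i, j] lives at offset i * a.stride(0) + j * a.stride(1) in the underlying storage:

>>> i, j = 2, 3
>>> a[i, j] == a.view(-1)[i * a.stride(0) + j * a.stride(1)]
tensor(True)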

This is still row major, right? If we set

# Not a PyTorch API
a.strides[0] = 1
a.strides[1] = a.shape[0]

This becomes column major, so strides are a convenient way to represent both row major and column major layouts.
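
In fact (my own aside, but this is standard PyTorch behavior), a transpose is implemented by simply swapping the strides; no data is moved:

>>> a = torch.randn(3, 5)
>>> a.stride()
(5, 1)
>>> a.t().stride()
(1, 5)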

Quick question: for a tensor of shape [1, 2, 3, 4] stored in row major order, what are its strides?

A tensor of shape [1, 2, 3, 4] has $1\times 2\times 3\times 4 = 24$ elements in total. Since it is stored in row major order, moving by one index along the 0th dimension skips all the elements in the remaining dimensions, that is $2\times 3\times 4 = 24$; moving along the 1st dimension skips $3\times 4 = 12$ elements; the same reasoning gives $4$ for the 2nd dimension; and on the last dimension we move only 1 element per index.
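
That reasoning can be written as a small helper (my own hypothetical row_major_strides function, not a PyTorch API):

def row_major_strides(shape):
    # The stride of dimension i is the product of the sizes of all later dimensions
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))

print(row_major_strides((1, 2, 3, 4)))  # (24, 12, 4, 1)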

>>> a = torch.randn(1,2,3,4)
>>> a.stride()
(24, 12, 4, 1)
