When training a neural network, we want the derivative of the output with respect to the values on the left side of the graph, e.g. $\frac{\partial y}{\partial x_1}$. There are two ways to get this derivative: forward mode and backward mode.
Forward Mode AD. In forward mode AD, we define
$$\dot{v}_i = \frac{\partial v_i}{\partial x_j}, \quad j \in \{1, 2\}$$
We then calculate each $\dot{v}_i$ with respect to the variable $x_1$ in the graph:
We start from the very left of the graph and propagate the derivative of each node with respect to $x_1$, until we reach the right side of the graph and obtain $\frac{\partial y}{\partial x_1}$. However, you may have noticed that we need to repeat the whole process for $x_2$, and in modern neural networks the input side is extremely high dimensional, which means we must run the computation many times to get all the partial derivatives.
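To make this concrete, here is a minimal forward-mode sketch using dual numbers; the example function $y = x_1 x_2 + \sin(x_1)$ and the `Dual` class are illustrative assumptions, not the graph from the notes.

```python
# Minimal forward-mode AD sketch: each value carries its tangent
# v_dot = dv/dx_j for one chosen input x_j.
import math

class Dual:
    def __init__(self, val, dot):
        self.val = val   # v_i
        self.dot = dot   # v_i_dot = d v_i / d x_j

    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)

    def __mul__(self, other):
        # product rule
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

def sin(x):
    return Dual(math.sin(x.val), math.cos(x.val) * x.dot)

# One forward pass gives dy/dx1 only; dy/dx2 needs a second pass with
# the seed tangents swapped.
x1 = Dual(2.0, 1.0)  # seed: dx1/dx1 = 1
x2 = Dual(3.0, 0.0)  # seed: dx2/dx1 = 0
y = x1 * x2 + sin(x1)
print(y.val, y.dot)  # y and dy/dx1 = x2 + cos(x1)
```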
Backward Mode AD. In backward mode AD, we define
$$\bar{v}_i = \frac{\partial y}{\partial v_i}$$
We then calculate each $\bar{v}_i$ in reverse order over the graph, starting from the very right side of the graph.
We start from the output node and propagate gradients all the way back to the input nodes. Notice that we obtained the gradients for both $x_1$ and $x_2$ in a single run; this is possible because there is only one output node.
In ML problems we almost always have a single output node, the scalar loss, and extremely high-dimensional inputs. Hence backward mode AD is much more efficient.
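Using the same toy function as above in PyTorch's reverse-mode autograd shows the payoff: a single backward pass from the scalar output yields the gradients of all inputs at once.

```python
# One backward pass through the scalar output y produces every input gradient.
import torch

x1 = torch.tensor(2.0, requires_grad=True)
x2 = torch.tensor(3.0, requires_grad=True)
y = x1 * x2 + torch.sin(x1)   # single scalar output

y.backward()                  # one reverse sweep
print(x1.grad)                # dy/dx1 = x2 + cos(x1)
print(x2.grad)                # dy/dx2 = x1
```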
2 Constructing Backward Graphs
The pseudocode for backward graph computation is shown as follows
```python
def compute_gradient(output_node):
    # Initialize gradients: the gradient of the output w.r.t. itself is 1
    node_to_grad = {output_node: [1.0]}

    # Traverse nodes in reverse topological order
    for node in reverse_topological_order(output_node):
        # Accumulate total gradient flowing into this node, Equ. (1)
        total_grad = sum(node_to_grad[node])

        # Propagate gradients to inputs
        for input_node in inputs(node):
            # Compute local gradient contribution via chain rule, Equ. (1)
            local_grad = total_grad * derivative(node, input_node)

            # Accumulate gradient for input node
            if input_node not in node_to_grad:
                node_to_grad[input_node] = []
            node_to_grad[input_node].append(local_grad)

    # Return the accumulated gradients, including those w.r.t. the original input(s)
    return node_to_grad
```
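The helpers reverse_topological_order, inputs, and derivative are left unspecified above; the following is a minimal, hypothetical way to fill them in and run the pseudocode on a toy graph (none of these class or variable names come from the notes).

```python
# Hypothetical, minimal graph representation used only to exercise the
# pseudocode above.
class Node:
    def __init__(self, op, parents=(), local_grads=()):
        self.op = op                  # e.g. "mul", "input"
        self.parents = list(parents)  # input nodes
        self.local_grads = dict(zip(parents, local_grads))  # d node / d parent

def inputs(node):
    return node.parents

def derivative(node, input_node):
    return node.local_grads[input_node]

def reverse_topological_order(output_node):
    order, seen = [], set()
    def visit(n):
        if n in seen:
            return
        seen.add(n)
        for p in n.parents:
            visit(p)
        order.append(n)
    visit(output_node)
    return reversed(order)

# y = x1 * x2 at x1=2, x2=3: dy/dx1 = 3, dy/dx2 = 2
x1, x2 = Node("input"), Node("input")
y = Node("mul", parents=(x1, x2), local_grads=(3.0, 2.0))
grads = compute_gradient(y)
print(sum(grads[x1]), sum(grads[x2]))  # 3.0 2.0
```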
The logic behind the pseudocode is straightforward. We start from the output node and accumulate the total gradient of each node, in case it has multiple outgoing edges, according to Equation (1). We then compute $\bar{v}_{k \to i} = \bar{v}_i \frac{\partial v_i}{\partial v_k}$, where $k$ is an input node of $i$. We can do this because $\bar{v}_i$ is already available from the accumulation step, so the gradients of the preceding nodes can be computed.
The dictionary node_to_grad is initialized as {output_node: [1.0]} because the output node connects to $y$ through an identity function, so its gradient is always 1.
The result is exactly the same as the gradient we calculated analytically.
Modern ML frameworks like PyTorch also construct a backward graph that is connected to the forward graph (since some forward results are needed to calculate the gradients).
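For example, PyTorch exposes the backward graph it records through the grad_fn attribute of a tensor; a quick way to peek at it:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.sin(x) * x

# The backward graph is built during the forward pass and hangs off grad_fn;
# the multiplication's backward node saves the forward values it needs
# (sin(x) and x) to compute the gradient.
print(y.grad_fn)                 # e.g. <MulBackward0 object at ...>
print(y.grad_fn.next_functions)  # backward nodes of its inputs
```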
Homework: How to derive the gradient of softmax cross entropy
3.1 Graph Optimization
The goal of graph optimization is to make the graph run faster. Formally, given a graph $G$, we want to rewrite $G$ into $G'$ such that $G'$ runs faster than $G$.
One case is a Conv2D layer followed by a BatchNorm layer: at inference time the two operators can be fused into a single convolution. After fusion the graph is smaller and runs faster.
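As an illustration, here is a minimal sketch of such a fusion at inference time (the fuse_conv_bn helper is an assumption for the sketch, not an API from the lecture): the BatchNorm affine transform is folded into the convolution's weight and bias.

```python
# Sketch (assumed implementation): fold a BatchNorm2d into the preceding
# Conv2d at inference time.
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups,
                      bias=True)
    # BN(y) = scale * y + shift, with scale = gamma / sqrt(var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    shift = bn.bias - bn.running_mean * scale
    conv_bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    with torch.no_grad():
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        fused.bias.copy_(conv_bias * scale + shift)
    return fused

conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8).eval()
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5))  # True
```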
Another case study is attention. Instead of computing Q, K, and V with three separate matmuls,
```python
Q = matmul(W_q, h)
K = matmul(W_k, h)
V = matmul(W_v, h)
```
We compute qkv in a merged fashion
```python
qkv = matmul(concat(W_q, W_k, W_v), h)
```
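A concrete sketch in PyTorch (the shapes, dimension names, and weight layout here are assumptions):

```python
# Fused QKV projection: one big matmul instead of three, then a split.
# Shapes are illustrative: h is [seq_len, d_model].
import torch

d_model, seq_len = 64, 10
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))
h = torch.randn(seq_len, d_model)

# Unfused: three separate matmuls
Q, K, V = h @ W_q.T, h @ W_k.T, h @ W_v.T

# Fused: concatenate the weights, do one matmul, then split the result
W_qkv = torch.cat([W_q, W_k, W_v], dim=0)        # [3*d_model, d_model]
qkv = h @ W_qkv.T                                # [seq_len, 3*d_model]
Q2, K2, V2 = qkv.split(d_model, dim=-1)

print(torch.allclose(Q, Q2, atol=1e-5),
      torch.allclose(K, K2, atol=1e-5),
      torch.allclose(V, V2, atol=1e-5))
```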
The variable qkv can then be split back into three tensors, as shown in the sketch above. You may ask why this is faster; the answer lies in the metric Arithmetic Intensity (AI), defined by
$$\text{AI} = \frac{\#\text{ops}}{\#\text{memory access bytes}} \qquad (2)$$
Higher AI means higher efficiency, because more time is spent on computation instead of memory access. Let's compare two implementations.
```c
void add(int n, float* A, float* B, float* C) {
    for (int i = 0; i < n; i++) {
        C[i] = A[i] + B[i];
    }
}

void mul(int n, float* A, float* B, float* C) {
    for (int i = 0; i < n; i++) {
        C[i] = A[i] * B[i];
    }
}

float *A, *B, *C, *D, *E, *tmp1, *tmp2;
// calculate E = D + ((A + B) * C)
add(n, A, B, tmp1);
mul(n, tmp1, C, tmp2);
add(n, tmp2, D, E);
```
Inside add and mul, each loop iteration performs three memory accesses and one arithmetic operation (either an add or a multiply), so AI = 1/3 (counting element accesses rather than bytes, for simplicity). The three function calls therefore have an average AI of 1/3.
Now let’s fuse them together
```c
void fused(int n, float* A, float* B, float* C, float* D, float* E) {
    for (int i = 0; i < n; i++) {
        E[i] = D[i] + (A[i] + B[i]) * C[i];
    }
}
```
Now each iteration performs 5 memory accesses and 3 arithmetic operations, so AI = 3/5, which is higher than 1/3.

How do we perform graph optimization in practice? One approach is to write rules and templates: scan the whole graph, and whenever you find a pattern of operators that can be fused, fuse them. Another technique is called auto discovery, which automatically searches for operators that can be fused.
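As a toy illustration of the rule-and-template approach (everything here, including the Node representation and the rule itself, is an assumption for the sketch), a pass that scans the graph and fuses an elementwise add feeding an elementwise mul:

```python
# Toy rule-based fusion pass: rewrite the pattern mul(add(a, b), c) into a
# single fused node, recursively over the graph.
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str
    inputs: list = field(default_factory=list)

def fuse_add_mul(graph_output: Node) -> Node:
    """Rewrite mul(add(a, b), c) -> fused_add_mul(a, b, c), bottom-up."""
    node = Node(graph_output.op, [fuse_add_mul(i) for i in graph_output.inputs])
    if node.op == "mul" and node.inputs and node.inputs[0].op == "add":
        a, b = node.inputs[0].inputs
        c = node.inputs[1]
        return Node("fused_add_mul", [a, b, c])
    return node

# E = D + ((A + B) * C), mirroring the C example above
A, B, C, D = (Node("input") for _ in range(4))
E = Node("add", [D, Node("mul", [Node("add", [A, B]), C])])
E_opt = fuse_add_mul(E)
print([i.op for i in E_opt.inputs])  # ['input', 'fused_add_mul']
```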
3.2 Parallelization
Diving one layer deeper, we arrive at the parallelization layer, where we want to parallelize the computation graph. In this layer, the main concerns are (a tiny partitioning sketch follows this list):
How to partition
How to communicate
How to schedule
Consistency (The final result doesn’t change after parallelization)
How to auto-parallelize (Algorithm builders are agnostic to the system)
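As a tiny sketch of the partitioning and consistency questions (the row-split scheme and all names here are assumptions, not the lecture's), consider a matmul split across two workers:

```python
# Partition Y = X @ W by splitting X along its rows across two "workers".
# Each worker computes its shard independently; concatenating the shards
# must reproduce the unpartitioned result (consistency).
import torch

X = torch.randn(8, 16)
W = torch.randn(16, 4)

X0, X1 = X.chunk(2, dim=0)           # partition: each worker holds half the rows
Y0, Y1 = X0 @ W, X1 @ W              # local compute (W is replicated)
Y_parallel = torch.cat([Y0, Y1], 0)  # "communication": gather the shards

print(torch.allclose(X @ W, Y_parallel, atol=1e-5))  # consistency check: True
```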
3.3 Runtime and Scheduling
One layer down, the overall goal is to schedule memory, compute, and communication so that the graph runs as fast as possible, overlapping communication with compute where possible, while still satisfying the memory constraints.
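A minimal sketch of the overlap idea using two CUDA streams (requires a GPU; the model, batch sizes, and loop structure are placeholders, not the notes' scheduler): the next batch's host-to-device copy is issued on a side stream while the default stream computes on the current batch.

```python
# Overlap communication (H2D copy) with compute using a second CUDA stream.
import torch

assert torch.cuda.is_available()
copy_stream = torch.cuda.Stream()
model = torch.nn.Linear(1024, 1024).cuda()
batches = [torch.randn(512, 1024, pin_memory=True) for _ in range(4)]

cur = batches[0].cuda(non_blocking=True)
for i in range(len(batches)):
    if i + 1 < len(batches):
        with torch.cuda.stream(copy_stream):      # communication on side stream
            nxt = batches[i + 1].cuda(non_blocking=True)
    out = model(cur)                              # compute on the default stream
    torch.cuda.current_stream().wait_stream(copy_stream)
    # (a full implementation would also call Tensor.record_stream on nxt)
    if i + 1 < len(batches):
        cur = nxt
```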
3.4 Operator Optimization
We want the fastest implementations of commonly used operators, like matmul. We may optimize them for different hardware, precisions, and shapes.
In this layer, the goal is to maximize the AI, which is defined in Equation (2).
4 Operator Optimization: A High Level View
4.1 Vectorization
```c
#include <xmmintrin.h>  // SSE intrinsics

float A[256], B[256], C[256];

// unvectorized
for (int i = 0; i < 256; ++i) {
    C[i] = A[i] + B[i];
}

// vectorized
for (int i = 0; i < 256; i += 4) {
    __m128 a = _mm_load_ps(&A[i]);  // Load 4 floats
    __m128 b = _mm_load_ps(&B[i]);  // Load 4 floats
    __m128 c = _mm_add_ps(a, b);    // SIMD add
    _mm_store_ps(&C[i], c);         // Store result
}
```
The vectorized version loads 4 floats at a time, adds them with a single vector add, then stores the results into array C.
4.2 Data Layout
Different layouts of the data in memory can also make a difference in efficiency.
Normally, a matrix is stored contiguously in memory, and we access it in either row-major order (A[i,j] = A.data[i*A.shape[1] + j]) or column-major order (A[i,j] = A.data[j*A.shape[0] + i]).
However, ML frameworks store data in a strided format. In PyTorch, Tensor.stride(i) tells you how many elements need to be skipped in memory to move by one index in the i-th dimension of the tensor, e.g.
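For instance (a minimal sketch; the shape [3, 5] is an assumption chosen to match the strides discussed next):

```python
>>> import torch
>>> a = torch.randn(3, 5)
>>> a.stride()
(5, 1)
```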
In the above example, moving by one index along the 0-th dimension skips 5 elements in memory, and moving along the 1-st dimension skips 1 element.
It's still row major, right? If we instead set
```python
# Not a PyTorch API
a.strides[0] = 1
a.strides[1] = a.shape[0]
```
This becomes column major, so strides make it convenient to represent both row-major and column-major layouts.
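In real PyTorch, a transposed view produces exactly such a column-major layout by changing only the strides, without copying any data:

```python
>>> a = torch.randn(5, 3)
>>> a.t().shape, a.t().stride()   # column major: strides are (1, shape[0])
(torch.Size([3, 5]), (1, 3))
```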
Quick question: for a tensor of shape [1, 2, 3, 4] stored in row major, what are its strides?
A tensor of shape [1, 2, 3, 4] has 1×2×3×4 = 24 elements in total. Since it is row major, moving by one index along the 0-th dimension skips all the elements in the remaining dimensions, that is 2×3×4 = 24; moving along the 1-st dimension skips 3×4 = 12 elements; the same reasoning gives 4 for the 2-nd dimension; and on the last dimension we move just 1 element per index.
```python
>>> a = torch.randn(1, 2, 3, 4)
>>> a.stride()
(24, 12, 4, 1)
```