01-Introduction
Last updated: June 23, 2025
These are lecture notes for the course CSE 234 - Data Systems for ML - LE [A00].
- From UC San Diego
- Prof. Hao Zhang
- Winter 2025
- Link: https://podcast.ucsd.edu/watch/wi25/cse234_a00/1
1 ML Models and Primitives
For system builders, we care about how to abstract the objects we work with. In today's ML community, we basically build these types of models: CNNs, RNNs, Transformers, and MoE.
But from a computation perspective, we can extract the common components of these models.
- CNN: Conv, Matmul, Softmax, ReLU, Batchnorm
- RNNs: Matmul, Sigmoid, tanh
- Transformer: Matmul, Softmax, GeLU, Layernorm
- MoE: Matmul, Softmax
And we can convert convolutions to matmuls (e.g., via im2col), so matmul covers roughly 80% of the problem. We really need to take care of matmul and accelerate it.
Hence, our primitives include matmul, addition, subtraction, element-wise multiplication, etc. But matmul is the most important thing, as the sketch below illustrates.
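For instance, here is a minimal PyTorch sketch of the im2col trick, which expresses a convolution as a single matmul (the shapes and random tensors below are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

# A 3x3 convolution written as one matmul via im2col (unfold).
x = torch.randn(1, 3, 32, 32)        # (batch, in_channels, H, W)
weight = torch.randn(8, 3, 3, 3)     # (out_channels, in_channels, kH, kW)

cols = F.unfold(x, kernel_size=3)    # (1, 3*3*3, 30*30): patches as columns
w_mat = weight.view(8, -1)           # (8, 27): each filter as a row
out = (w_mat @ cols).view(1, 8, 30, 30)  # the convolution as a matmul

# Matches the built-in convolution up to floating-point error.
ref = F.conv2d(x, weight)
print(torch.allclose(out, ref, atol=1e-4))  # True
```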
2 Dataflow Graphs
How do we represent our computation? We use dataflow graphs. However, different ML frameworks have different flavors of representing graphs.
So what is a dataflow graph? Here's an example for the expression a * b + 3:
```mermaid
graph LR
a --> mul
b --> mul
mul --> add[add-const]
const3([3]) --> add
```
As you can see, a few things are used to represent the graph:
- Node: represents the computation, or operator
- Edge: represents the data dependency, or the direction data flows
Note that a node also represents its output tensor, and sometimes an input constant tensor.
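As a minimal sketch, PyTorch records exactly this graph when you evaluate the expression (the values here are arbitrary):

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(5.0, requires_grad=True)

out = a * b + 3   # builds the graph: mul -> add-const
print(out)        # tensor(13., grad_fn=<AddBackward0>)
print(out.grad_fn.next_functions)  # edges back to the mul node
```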
2.1 Static Graph
Way before PyTorch, the dominant framework was TensorFlow. The characteristic of TensorFlow is its static graph: you build the entire graph first and then execute it.
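For example, a minimal TensorFlow 1.x-style sketch of the graph definition (the names w, x, and y are illustrative assumptions):

```python
import tensorflow as tf  # TensorFlow 1.x API

# Define the computation symbolically; nothing is executed yet.
x = tf.placeholder(tf.float32, shape=(None, 784))
w = tf.Variable(tf.random_normal((784, 10)))
y = tf.nn.softmax(tf.matmul(x, w))
```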
Then you get a graph like this
```mermaid
graph LR
w --> matmul
x --> matmul
matmul --> softmax
softmax --> y
```
But the real training step happens only after you have defined all computations, or say, the entire graph.
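Continuing the sketch above, the training step might look like this (the loss, optimizer, and batch variables are illustrative assumptions):

```python
y_true = tf.placeholder(tf.float32, shape=(None, 10))
loss = -tf.reduce_sum(y_true * tf.log(y))  # cross-entropy loss
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Nothing has executed until now; the computation happens here,
    # with batch_x and batch_y coming from your data pipeline (assumed):
    sess.run(train_op, feed_dict={x: batch_x, y_true: batch_y})
```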
Also, after you have defined the graph, you can't change it. You can't use a Python if statement (although there are tricks such as the switch op). But sometimes we need to change the graph based on the input.
2.2 Dynamic Graph
PyTorch uses a totally different flavor called the dynamic graph: the graph is created on the fly. One major benefit is that you can add a print in the middle of development and check the intermediate results. So you can create-and-run instead of run-after-creation.
In TensorFlow, by contrast, you have to do something like this:
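Continuing the earlier sketch, inspecting an intermediate tensor requires explicitly fetching it through the session (the names are illustrative assumptions):

```python
h = tf.matmul(x, w)        # intermediate result we want to inspect
y = tf.nn.softmax(h)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # print(h) would only show a symbolic Tensor, not its values;
    # we must fetch the value explicitly through sess.run:
    h_val = sess.run(h, feed_dict={x: batch_x})
    print(h_val)
```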
So everything still happens in sess.run(), and you must define everything before it.
The TensorFlow way is called symbolic and the PyTorch way is called imperative.
We can summarize the pros and cons of both ways.
- Symbolic
  - pros
    - Easy to optimize for system developers
    - Much more efficient, sometimes 10x faster than imperative
  - cons
    - The way of programming might be counter-intuitive
    - Hard to debug for user programs
    - Less flexible: you need to write symbols before actually doing anything
- Imperative
  - pros
    - More flexible: write one line, evaluate one line
    - Easy to program and easy to debug
  - cons
    - Less efficient
    - More difficult to optimize
However, people love PyTorch because it's much more intuitive for programming and developing models; PyTorch won the market and dominates today's AI community.
2.3 Just-in-Time (JIT) Compilation
Ideally, we want define-and-run during development and define-then-run during deployment. The way to combine them is JIT. In PyTorch, you can use torch.compile() to turn a code snippet into a static graph. However, it doesn't work for truly dynamic graphs (where the program itself describes a graph that depends on the input): there JIT gives no acceleration, so JIT requires the program to be a static graph.
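To make this concrete, here is a minimal sketch of JIT-compiling a static-graph function (the function f is an illustrative assumption):

```python
import torch

def f(x):
    # Purely static computation: the same graph for every input.
    return torch.softmax(x @ x.T, dim=-1)

f_compiled = torch.compile(f)  # capture and compile into a static graph

x = torch.randn(8, 8)
y = f_compiled(x)  # compiled on the first call; later calls reuse the graph
```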
We can compare static and dynamic graphs as follows from a system perspective:
- Static graph
  - Define once, optimize once, execute many times
  - Execution: once defined, every subsequent run follows the same defined computation.
- Dynamic graph
  - Difficulty in expressing complex control-flow logic (system developers need special designs for control logic)
  - Complexity of the computation graph implementation
  - Difficulty in debugging: since the graph depends on the input, you must figure out the exact input to reproduce a bug.
However, for language applications the input length varies, so we really need dynamic programs.
So how do we handle dynamic graphs? Can we JIT them?
First, we can just do define-and-run and not care about performance: we build the graph line by line and forget about JIT. Another approach is to introduce control-flow ops, or piecewise compilation with guards.
Examples of control-flow primitives are the switch and merge operators.
```mermaid
flowchart LR
direction TB
DataIn1["Data (Tensor)"]
Pred["Predicate (Boolean)"]
SwitchOp["Switch"]
OutF["Data (Tensor)"]
Dead1["Dead Tensor"]
DataIn1 --> SwitchOp
Pred --> SwitchOp
SwitchOp -->|T| OutF
SwitchOp -->|F| Dead1
```
This operator receives two inputs, data and a predicate (boolean), and has two outputs: the data is forwarded to the branch selected by the predicate, while the other branch receives a dead tensor.
```mermaid
flowchart LR
direction TB
In2["Data (Tensor)"]
Dead2["Dead Tensor"]
MergeOp["Merge"]
Out2["Data (Tensor)"]
In2 --> MergeOp
Dead2 --> MergeOp
MergeOp --> Out2
```
This operator receives two inputs, data and a dead tensor, and has one output: the data.
So if we implement a lambda function x < y ? (x + z) : y**2, we can combine the switch and merge operators: switch takes the condition and outputs a dead tensor (None) on the branch whose condition is not satisfied, and merge combines the dead tensor with the valid tensor.
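In TensorFlow 1.x, tf.cond is built on exactly these Switch/Merge primitives; here is a minimal sketch of the expression above (the placeholder names are illustrative assumptions):

```python
import tensorflow as tf  # TensorFlow 1.x API

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

# tf.cond lowers to Switch/Merge: Switch routes the inputs to the taken
# branch (the other branch receives dead tensors), and Merge forwards
# whichever branch actually produced a value.
result = tf.cond(x < y, lambda: x + z, lambda: y ** 2)

with tf.Session() as sess:
    print(sess.run(result, feed_dict={x: 1.0, y: 2.0, z: 3.0}))  # 4.0
```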
But what is the gradient of these primitives? Although we can indeed derive it, it's too complicated, so we can do it in a simpler way.
And that is piecewise compilation.
Case 1. The input shape stays constant (only the values will change). So, to make this work, we just compile for all possible inputs, i.e., one graph per possible shape.
Case 2. A graph that is static, then dynamic, then static. We insert guards between the static parts: we compile the first static part, leave the dynamic part to run in pure Python, feed its result into the second static part, and compile that second static part as well, as sketched below.
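A minimal PyTorch sketch of this static-dynamic-static pattern (the pipeline function is an illustrative assumption):

```python
import torch

def pipeline(x):
    h = torch.relu(x @ x.T)          # static part 1: compiled
    if h.sum() > 0:                  # data-dependent branch: graph break;
        h = h * 2                    # this control flow runs in pure Python
    return torch.softmax(h, dim=-1)  # static part 2: compiled separately

compiled = torch.compile(pipeline)   # inserts guards around the break
out = compiled(torch.randn(4, 4))
```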