>>> import torch
>>> a = torch.tensor([1, 2, 3, 4], requires_grad=True, dtype=torch.float32)
>>> a
tensor([1., 2., 3., 4.], requires_grad=True)
>>> b = torch.tensor([1, 2, 1, 2], requires_grad=True, dtype=torch.float32)
>>> b
tensor([1., 2., 1., 2.], requires_grad=True)
>>> c = torch.clamp(a * b, 1, 5)
>>> c
tensor([1., 4., 3., 5.], grad_fn=<ClampBackward1>)
>>> d = c.sum()
>>> d
tensor(13., grad_fn=<SumBackward0>)
>>> d.backward()
>>> b.grad
tensor([1., 2., 3., 0.])
>>> a.grad
tensor([1., 2., 1., 0.])
In this example, we create two tensors

a = [1, 2, 3, 4]
b = [1, 2, 1, 2]

The two tensors are multiplied element-wise, and the product is clamped to the range [1, 5] by torch.clamp. Let's recalculate the chain rule of this computational graph step by step.
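Before deriving anything, it helps to see which elements the clamp actually touches. One way to make the gradient routing explicit is to emulate clamp with torch.where, so each element either keeps its own value (gradient flows) or is replaced by a constant bound (gradient is cut). This torch.where formulation is my own re-derivation for illustration, not how PyTorch implements ClampBackward:

```python
import torch

a = torch.tensor([1., 2., 3., 4.], requires_grad=True)
b = torch.tensor([1., 2., 1., 2.], requires_grad=True)

x = a * b  # [1., 4., 3., 8.]

# clamp(x, 1, 5) rewritten element-wise: replace out-of-range values
# with the constant bound; in-range values pass through unchanged.
c = torch.where(x > 5, torch.full_like(x, 5.),
                torch.where(x < 1, torch.full_like(x, 1.), x))

d = c.sum()
d.backward()

print(a.grad)  # tensor([1., 2., 1., 0.])
print(b.grad)  # tensor([1., 2., 3., 0.])
```

Only the last element (1 · 4 · 2 · 2 = 8 > 5) is replaced by a constant, so only its gradient path is cut, matching the clamp result above.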
2 Mathematical Derivations
What we want PyTorch to calculate are the gradients of tensors a and b. By the chain rule, we have

∇_{a_i} d = (∂d/∂c_i) · (∂c_i/∂(a_i b_i)) · (∂(a_i b_i)/∂a_i) = 1 · 1[1 ≤ a_i b_i ≤ 5] · b_i

and symmetrically ∇_{b_i} d = 1 · 1[1 ≤ a_i b_i ≤ 5] · a_i.
Here, the clamp function has slope 1 if the value is not clamped, and the slope becomes 0 if the value is clamped (because the output no longer changes with the input). Hence, the partial derivative of the clamp function is effectively an indicator function of whether the value is clamped.
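This indicator behavior can be checked directly on a tiny example, with one value below the range, one inside, and one above (the values 0.5, 3.0, 7.0 here are arbitrary probes of my own choosing):

```python
import torch

# Gradient of clamp(x, 1, 5) w.r.t. x should be 1 where 1 <= x <= 5
# and 0 where the value is clamped to a bound.
x = torch.tensor([0.5, 3.0, 7.0], requires_grad=True)
y = torch.clamp(x, 1, 5).sum()
y.backward()

print(x.grad)  # tensor([0., 1., 0.])
```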
Now we can list the gradient values:
∇_{a_0} d = 1 · 1 · 1 = 1
∇_{a_1} d = 1 · 1 · 2 = 2
∇_{a_2} d = 1 · 1 · 1 = 1
∇_{a_3} d = 1 · 0 · 2 = 0
which matches the result ∇_a d = [1, 2, 1, 0] reported by PyTorch. The gradients of tensor b can be verified in the same way.
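The same three-factor arithmetic (upstream gradient 1, the clamp indicator, and the other operand from the product rule) can be spelled out in plain Python as a sanity check:

```python
a = [1.0, 2.0, 3.0, 4.0]
b = [1.0, 2.0, 1.0, 2.0]

# Indicator: 1 where a_i * b_i stays inside [1, 5], 0 where it is clamped.
indicator = [1.0 if 1 <= x * y <= 5 else 0.0 for x, y in zip(a, b)]

grad_a = [ind * y for ind, y in zip(indicator, b)]  # 1 * indicator * b_i
grad_b = [ind * x for ind, x in zip(indicator, a)]  # 1 * indicator * a_i

print(grad_a)  # [1.0, 2.0, 1.0, 0.0]  -- matches a.grad
print(grad_b)  # [1.0, 2.0, 3.0, 0.0]  -- matches b.grad
```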
Conclusion
Intuitively, the clamp function cuts the gradient chain. Since the chain rule multiplies partial derivatives together, once a 0 appears in the chain, nothing propagates through that multiplicative path (unless another additive path still contributes to the gradient).
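The additive exception can be seen concretely by adding a second path from a to the loss. In this variant (my own extension of the example above, not from the original session), the clamped element still receives a gradient of 1 through the direct additive path:

```python
import torch

a = torch.tensor([1., 2., 3., 4.], requires_grad=True)
b = torch.tensor([1., 2., 1., 2.], requires_grad=True)

# Two paths from a to d: one through clamp (cut where clamped),
# one direct additive path with slope 1 everywhere.
d = (torch.clamp(a * b, 1, 5) + a).sum()
d.backward()

print(a.grad)  # tensor([2., 3., 2., 1.])  -- clamp-path gradient plus 1
```

The last element's clamp path contributes 0 as before, but the gradient is 1 rather than 0 because the additive branch survives.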