PyTorch Clamp And The Gradient

Last updated on March 22, 2026

1 From an Example

Let’s look at the following example:

>>> import torch
>>> a = torch.tensor([1, 2, 3, 4], requires_grad=True, dtype=torch.float32)
>>> a
tensor([1., 2., 3., 4.], requires_grad=True)
>>> b = torch.tensor([1, 2, 1, 2], requires_grad=True, dtype=torch.float32)
>>> b
tensor([1., 2., 1., 2.], requires_grad=True)
>>> c = torch.clamp(a * b, 1, 5)
>>> c
tensor([1., 4., 3., 5.], grad_fn=<ClampBackward1>)
>>> d = c.sum()
>>> d
tensor(13., grad_fn=<SumBackward0>)
>>> d.backward()
>>> b.grad
tensor([1., 2., 3., 0.])
>>> a.grad
tensor([1., 2., 1., 0.])

In this example, we created two tensors

$$a = [1, 2, 3, 4] \quad b = [1, 2, 1, 2]$$

These two tensors are multiplied element-wise, and the product is then clamped to $[1, 5]$ by torch.clamp. Let’s recompute the chain-rule gradients of this computational graph step by step.

2 Mathematical Derivations

What we want PyTorch to compute are the gradients of tensors $a$ and $b$. We have

$$\frac{\partial d}{\partial a_i} = \frac{\partial d}{\partial c_i}\frac{\partial c_i}{\partial x_i}\frac{\partial x_i}{\partial a_i}$$

where $x = a * b$ is the element-wise product.

We have

$$\frac{\partial x_i}{\partial a_i} = \frac{\partial (a_i b_i)}{\partial a_i} = b_i \quad \frac{\partial c_i}{\partial x_i} = \mathbb{I}(\text{not clamped}) \quad \frac{\partial d}{\partial c_i} = \frac{\partial \sum_j c_j}{\partial c_i} = 1$$

Here, the clamp function has slope 1 if the value is not clamped, and the slope becomes 0 if it is clamped (because the output no longer changes with the input there). Hence, the partial derivative of the clamp function is effectively an indicator function of whether the value was left unclamped.
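We can verify this indicator-function view numerically. The sketch below reproduces the session above and recomputes both gradients by hand from the mask (note that in the example, the element whose product is exactly 1, i.e. right on the lower bound, still receives gradient 1, so the boundary counts as "not clamped"):

```python
import torch

# Reproduce the example from the session above.
a = torch.tensor([1., 2., 3., 4.], requires_grad=True)
b = torch.tensor([1., 2., 1., 2.], requires_grad=True)

x = a * b                 # element-wise product: [1., 4., 3., 8.]
c = torch.clamp(x, 1, 5)  # 8 is clamped to 5:   [1., 4., 3., 5.]
c.sum().backward()

# Indicator of "not clamped": 1 where 1 <= x_i <= 5, else 0.
mask = ((x >= 1) & (x <= 5)).float()   # [1., 1., 1., 0.]

# By the chain rule: grad_a = 1 * mask * b, and symmetrically for b.
assert torch.equal(a.grad, mask * b.detach())  # [1., 2., 1., 0.]
assert torch.equal(b.grad, mask * a.detach())  # [1., 2., 3., 0.]
```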

Now we can list the gradient values:

| $\nabla_{a_0}d$ | $\nabla_{a_1}d$ | $\nabla_{a_2}d$ | $\nabla_{a_3}d$ |
| --- | --- | --- | --- |
| $1 \cdot 1 \cdot 1 = 1$ | $1 \cdot 1 \cdot 2 = 2$ | $1 \cdot 1 \cdot 1 = 1$ | $1 \cdot 0 \cdot 2 = 0$ |

which matches the result $\nabla_{a}d = [1, 2, 1, 0]$. The same check works for tensor $b$.
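Worked out explicitly for $b$, using $\partial x_i / \partial b_i = a_i$ and the same indicator values:

$$\nabla_{b_i}d = 1 \cdot \mathbb{I}(\text{not clamped}) \cdot a_i \;\Rightarrow\; [1 \cdot 1 \cdot 1,\; 1 \cdot 1 \cdot 2,\; 1 \cdot 1 \cdot 3,\; 1 \cdot 0 \cdot 4] = [1, 2, 3, 0]$$

which matches the b.grad output in the session above.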

3 Conclusion

Intuitively, the clamp function cuts the gradient chain. Since the chain rule multiplies partial derivatives together, once a 0 appears in the chain, the multiplicative propagation along that path is gone (gradients can still reach a parameter through other, additive paths).
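A minimal sketch of this cutting effect (my own illustrative example, not from the session above), where a single clamped scalar zeroes out the gradient of everything upstream of it:

```python
import torch

# Once a value is clamped, every parameter upstream of it on that
# element's chain receives zero gradient.
w = torch.tensor(3.0, requires_grad=True)
y = torch.clamp(w * 2, max=5)  # w * 2 = 6 is clamped to 5
y.backward()
print(w.grad)                  # tensor(0.)
```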


Author: Lynx Li · Posted on March 22, 2026
https://lynx-li.github.io/2026/03/22/ai_infra/torch/pytorch_clamp/