When training a neural network with PyTorch, an important step is backpropagation, which typically looks like this:
loss = criterion(y_pred, y)
loss.backward()
The gradients of the weight tensors are calculated here, even though neither the gradients nor the weight matrices appear explicitly in the code. The mathematical formulation of backpropagation is well known. In this post, we try to understand, from a programming perspective, how PyTorch calculates gradients when loss.backward() is called.
Part of the PyTorch source code is written in C++. We will mostly focus on the Python side and give one C++ example to aid understanding.
Background Knowledge
The backward() method
PyTorch uses the autograd package for automatic differentiation. For a tensor y, we can calculate the gradients with respect to the inputs in two equivalent ways:
y.backward()
torch.autograd.backward(y)
After calling .backward(), we can check the gradient of an input tensor x using the .grad attribute:
x.grad
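As a minimal illustration (the tensor value here is arbitrary), here is the whole round trip from forward pass to gradient:

```python
import torch

# A scalar input that requires a gradient.
x = torch.tensor(3.0, requires_grad=True)
y = x * x  # y = x^2, so dy/dx = 2x

y.backward()   # equivalent to torch.autograd.backward(y)
print(x.grad)  # tensor(6.)
```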
Calculation Graph
PyTorch generates a dynamic computation graph when the forward function of the network is called. We borrow a toy example from here.
Suppose our model is the following, and we actually run it (not just define it):
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a + b
d = torch.tensor(4.0, requires_grad=True)
e = c * d
e.backward()
A computation graph is generated as you run the forward pass. The corresponding graph for this model looks like this:
The nodes represent the tensors and the circles represent the operations. More importantly, when the forward computation graph is generated, a backward graph is simultaneously generated to calculate the gradients. The whole picture looks like this:
We can see some names in yellow rectangles. They actually implement the backpropagation operations. We will talk about them later.
Gradient Calculation Process
We use the above example to explain. First let's look at the tensor e:
tensor(20., grad_fn=<MulBackward0>)
You will find an attribute called grad_fn. According to the documentation, grad_fn stores a reference to the backward propagation function of the tensor. That sounds confusing.
To put it plainly, grad_fn stores the corresponding backpropagation method, based on how the tensor (e here) was calculated in the forward pass. In this case e = c * d, so e is generated through multiplication. Therefore grad_fn here is MulBackward0, the backpropagation operation for multiplication.
grad_fn has an attribute called next_functions. If we check e.grad_fn.next_functions, it returns a tuple of tuples:
((<AddBackward0 at 0x268c6d3e668>, 0), (<AccumulateGrad at 0x268c6d3e588>, 0))
If you remember the meaning of MulBackward0, you will notice that AddBackward0 here represents an addition operation in the forward pass. Given that e = (a + b) * d, the pattern is clear: grad_fn traverses all members of its next_functions, forming a chain structure for the gradient calculation process.
In this case, to calculate the gradient of e with respect to the input a, PyTorch needs to calculate the gradient of the multiplication operation first and then that of the addition operation.
AccumulateGrad plays a similar role to MulBackward0 and AddBackward0, but it belongs to an original (leaf) input. In this case d is an original input, unlike c, which is calculated as a + b.
AccumulateGrad has two attributes of interest: next_functions and variable.
If we go further to c = a + b, we pick AddBackward0 and take a look at its next_functions:
ag = e.grad_fn.next_functions[0][0]  # <AddBackward0 at 0x268c6d3e668>
ag.next_functions
# ((<AccumulateGrad at 0x268c6d3e978>, 0),
#  (<AccumulateGrad at 0x268c6d3e898>, 0))
Since c = a + b, and a and b are original inputs, their next_functions entries are AccumulateGrad nodes, just like the one for d.
We then check its variable attribute:
ag.variable
It returns an error: AttributeError: 'AddBackward0' object has no attribute 'variable'. The reason is that c = a + b, so c is not an original input and is therefore not stored as a variable.
However, if we go with the AccumulateGrad node (for d):
ag = e.grad_fn.next_functions[1][0]  # <AccumulateGrad at 0x268c6d3e588>
ag.variable
It returns the variable:
tensor(4., requires_grad=True)
This should look familiar! It's just the definition of d (d = torch.tensor(4.0, requires_grad=True)).
Similarly, if we follow the AccumulateGrad nodes for a and b, we will also see familiar tensors. We omit that here.
Now we have seen the whole process from the output e to the inputs a, b, and d.
To sum up, when we call e.backward() to calculate the gradients, PyTorch first calculates the derivative of e with respect to its immediate inputs, then traverses next_functions. Whenever a next_function is AccumulateGrad, it marks an original input, and the calculated gradient is stored in that variable's .grad attribute.
PyTorch keeps traversing the computation graph until it reaches all the original inputs; at that moment, the gradients of all inputs have been updated. During this process, the chain rule is applied: first the derivative of the multiplication (de/dc) is calculated, then the derivative of the addition (dc/da), and their product gives a's gradient (de/da).
Derivative Calculation
By now, we understand the process of gradient calculation. But one question remains: how is each individual derivative actually calculated? The answer is extremely simple: it is hard-coded in classes such as MulBackward0 and AddBackward0. For example, the derivative of the multiplication e = c * d is de/dc = d. The MulBackward0 object has saved the values of c and d during the forward pass, so it directly returns d as the derivative. No real derivative calculation happens here.
A related question is how the derivative of an activation function is calculated. It works the same way as MulBackward0: the class directly returns the needed value without any symbolic calculation. Take the sin() activation as an example. Here is an image of its C++ implementation:
We can see the line auto grad_result = grad * self.cos(), which simply multiplies the incoming gradient by cos(self). No calculation, no intelligence, just memory.
Conclusion
In this post, we introduced how gradients are generated in PyTorch at the programming level. We also showed how exactly each derivative is calculated.
References
To write this post, I read several articles related to this topic. They were really helpful, and you may refer to them to gain a better understanding.