Source code analysis for Autograd

Robin Dong 2018-08-23 12:07

Autograd is a convenient tool that automatically differentiates native Python and NumPy code.

Let’s look at an example first:

import autograd.numpy as np
from autograd import grad

def f(x):
  return x * x + 1

grad_f = grad(f)
print(grad_f(1.6))

The result is 3.2.

Here f(x) = x * x + 1, so its derivative is 2 * x. At x = 1.6 that gives 3.2, so the result is correct.
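If you want to double-check a value like this without Autograd, a central finite difference is a quick sanity test. The helper name numerical_derivative and the step size h below are my own choices for illustration, not part of Autograd:

```python
def f(x):
    return x * x + 1

def numerical_derivative(fun, x, h=1e-6):
    # Central difference: (f(x + h) - f(x - h)) / (2h) approximates f'(x).
    return (fun(x + h) - fun(x - h)) / (2 * h)

estimate = numerical_derivative(f, 1.6)
print(estimate)  # should be very close to 2 * 1.6
```

The estimate only approximates the derivative, while Autograd computes it from the gradient of each operation.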

The grad() function actually returns a ‘function object’, which here is ‘grad_f’. When we call grad_f(1.6), it ‘traces’ f(x) through ‘trace()’, whose ‘fun’ argument is our f(x) function.



Inside ‘trace()’, f() is actually called not with ‘x’ itself but with an ArrayBox object. The ArrayBox object has two purposes:

1. Go through all the operations in f() along with ‘x’, so that it can compute the real result of f(x)
2. Collect the corresponding gradients of the operations in f()

The ArrayBox class already overrides all the basic arithmetic operations, such as add/subtract/multiply/divide/square, so it can catch every operation in f(x).
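To make the idea concrete, here is a tiny stand-in for such a box. It is only a sketch of the operator-overloading trick: the names ToyBox and toy_trace are made up for this post, and Autograd's real ArrayBox and trace() are considerably more involved.

```python
class ToyBox:
    """Hypothetical stand-in for ArrayBox: wraps a value and logs each operation."""

    def __init__(self, value, log):
        self.value = value
        self.log = log  # shared list of operation names, in execution order

    @staticmethod
    def _unwrap(other):
        # Accept both boxed values and plain Python numbers.
        return other.value if isinstance(other, ToyBox) else other

    def __add__(self, other):
        self.log.append("add")
        return ToyBox(self.value + self._unwrap(other), self.log)

    __radd__ = __add__  # addition is commutative, so 1 + box works too

    def __mul__(self, other):
        self.log.append("multiply")
        return ToyBox(self.value * self._unwrap(other), self.log)

    __rmul__ = __mul__


def toy_trace(fun, x):
    # Call fun with a box instead of the raw number, as described above.
    log = []
    result = fun(ToyBox(x, log))
    return result.value, log


def f(x):
    return x * x + 1

value, ops = toy_trace(f, 1.6)
print(value, ops)
```

Note that f() is unchanged. It simply receives a box instead of a number, and the overridden operators do the recording.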




After catching all the operations, ArrayBox can look up the gradients table to get the corresponding gradients and then use the chain rule to obtain the final gradient.
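Here is a minimal sketch of that lookup-plus-chain-rule step on a hand-built graph. Everything in it (GRADIENTS, FORWARD, Node, apply_op, backward) is a hypothetical name chosen for illustration, and the two-entry table is far smaller than Autograd's real one.

```python
# Hypothetical gradients table: for each operation, the partial derivatives
# of the output with respect to each of its two inputs.
GRADIENTS = {
    "add": lambda a, b: (1.0, 1.0),
    "mul": lambda a, b: (b, a),
}
FORWARD = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
}

class Node:
    def __init__(self, value, op=None, parents=()):
        self.value = value
        self.op = op            # name of the operation that produced this node
        self.parents = parents  # input nodes of that operation
        self.grad = 0.0         # accumulated d(output)/d(this node)

def apply_op(op, a, b):
    return Node(FORWARD[op](a.value, b.value), op, (a, b))

def backward(out):
    # Order the nodes so that every node comes after its inputs.
    order, seen = [], set()
    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.parents:
            visit(parent)
        order.append(node)
    visit(out)

    out.grad = 1.0
    for node in reversed(order):
        if node.op is None:
            continue  # an input or constant: nothing to propagate
        local = GRADIENTS[node.op](*(p.value for p in node.parents))
        for parent, d in zip(node.parents, local):
            parent.grad += node.grad * d  # chain rule

# f(x) = x * x + 1 at x = 1.6
x = Node(1.6)
y = apply_op("add", apply_op("mul", x, x), Node(1.0))
backward(y)
print(y.value, x.grad)
```

Because x feeds the multiplication twice, its gradient is accumulated from both inputs, which is where the factor of 2 in 2 * x comes from.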

The gradients table is shown below:



Besides this, Autograd has other tricks to do its work. Take the function wrapper ‘@primitive’ as an example: this decorator lets users add new custom-defined operations to Autograd.
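The toy below shows one way a decorator like this could work. It is my own simplification, not Autograd's implementation: toy_primitive, def_gradient, Traced and toy_grad are invented names, and the sketch only handles chains of single-argument functions.

```python
GRADIENTS = {}  # hypothetical table: wrapped function -> derivative rule


class Traced:
    """A traced value that remembers which operation produced it."""

    def __init__(self, value, parent=None, op=None):
        self.value = value
        self.parent = parent
        self.op = op


def toy_primitive(fun):
    # Wrap fun so that calls on a Traced value are recorded as a single step
    # and fun itself only ever sees the raw number.
    def wrapped(x):
        if isinstance(x, Traced):
            return Traced(fun(x.value), parent=x, op=wrapped)
        return fun(x)
    wrapped.__name__ = fun.__name__
    return wrapped


def def_gradient(op, rule):
    # Register the derivative of a wrapped function.
    GRADIENTS[op] = rule


def toy_grad(fun):
    def gradfun(x):
        node = fun(Traced(x))
        g = 1.0
        while node.op is not None:  # walk back to the input, applying the chain rule
            g *= GRADIENTS[node.op](node.parent.value)
            node = node.parent
        return g
    return gradfun


@toy_primitive
def cube(x):
    return x ** 3

def_gradient(cube, lambda x: 3 * x ** 2)


@toy_primitive
def double(x):
    return 2 * x

def_gradient(double, lambda x: 2.0)


result = toy_grad(lambda x: cube(double(x)))(1.0)
print(result)  # d/dx (2x)^3 = 24 x^2, so 24.0 at x = 1.0
```

The point is that the body of a wrapped function is never traced. The framework relies only on the gradient rule the user registers for it.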

The source code of Autograd is nice and neat. Its examples include a fully connected network, a CNN, and even an RNN. Let’s take a glimpse at Autograd’s implementation of the Adam optimizer to get a feel for its concise code style:
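As a rough illustration of how little code the Adam update rule needs, here is a scalar sketch in plain Python. It is not Autograd's source: the function name toy_adam and its defaults are my own assumptions, it works on a single float rather than arrays, and the gradient is written by hand instead of coming from grad().

```python
import math

def toy_adam(grad, x, num_iters=100, step_size=0.001, b1=0.9, b2=0.999, eps=1e-8):
    m = 0.0  # running average of the gradient (first moment)
    v = 0.0  # running average of the squared gradient (second moment)
    for i in range(num_iters):
        g = grad(x)
        m = (1 - b1) * g + b1 * m
        v = (1 - b2) * g * g + b2 * v
        mhat = m / (1 - b1 ** (i + 1))  # bias correction for the zero start
        vhat = v / (1 - b2 ** (i + 1))
        x = x - step_size * mhat / (math.sqrt(vhat) + eps)
    return x

# Minimise f(x) = x * x + 1, whose gradient is 2 * x; the minimum is at x = 0.
x_min = toy_adam(lambda x: 2 * x, 1.6, num_iters=2000, step_size=0.01)
print(x_min)
```

With Autograd installed, the hand-written lambda could be replaced by grad(f) from the first example.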


