# Error Propagation with autodiff

Klaus Reygers, 2021

## Automatic differentiation (AD)

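Automatic differentiation (autodiff, AD) is an efficient and numerically stable way to compute derivatives of functions implemented as computer programs. The derivatives are exact up to floating-point round-off; there is no truncation error. Computing the gradient of a scalar-valued function with respect to all of its inputs (in reverse mode) requires at most a small constant factor more arithmetic operations than evaluating the original program.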

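Automatic differentiation is distinct from **symbolic differentiation** (manipulating mathematical expressions, as computer algebra systems such as SymPy do) and **numerical differentiation** (approximating derivatives with finite differences, which suffers from truncation and round-off errors).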
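To make the contrast concrete, here is a small sketch comparing symbolic differentiation with SymPy to a forward finite difference for the illustrative function $x \sin x$ at $x = 2$. The helper names `forward_difference` and `g` are hypothetical. The finite-difference error depends on the step size, because truncation error dominates for large steps and round-off error for very small ones.

```python
import math
import sympy as sp

# Symbolic differentiation: SymPy transforms the expression itself.
t = sp.symbols('t')
df_symbolic = sp.diff(t * sp.sin(t), t)  # mathematically sin(t) + t*cos(t)
exact = float(df_symbolic.subs(t, 2.0))

# Numerical differentiation: forward finite difference with a given step size.
def forward_difference(func, x, step):
    return (func(x + step) - func(x)) / step

def g(x):
    return x * math.sin(x)

errors = {}
for step in (1e-1, 1e-6, 1e-12):
    errors[step] = abs(forward_difference(g, 2.0, step) - exact)
    print(f"step = {step:.0e}: error = {errors[step]:.1e}")
```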

The basic idea is to augment each elementary mathematical operation so that it computes not only its value but also its derivative. The derivative of a composite function (a sequence of primitive operations, each with a known rule for its derivative) is then obtained by applying the chain rule repeatedly.
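This idea can be sketched in a few lines of plain Python with dual numbers, which implement forward-mode AD: every value carries its derivative along, and each operation applies its own derivative rule. The class `Dual` and the helpers `dual_sin` and `derivative` are hypothetical names for illustration, not part of any library, and only addition, multiplication and sine are supported.

```python
import math

class Dual:
    """Forward-mode AD value: function value `val` and derivative `der`."""

    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    @staticmethod
    def _wrap(other):
        # Treat plain numbers as constants (derivative zero).
        return other if isinstance(other, Dual) else Dual(other, 0.0)

    def __add__(self, other):
        other = Dual._wrap(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual._wrap(other)
        # Product rule: (u v)' = u' v + u v'
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def dual_sin(x):
    # Chain rule: d/dx sin(u) = cos(u) * u'
    return Dual(math.sin(x.val), math.cos(x.val) * x.der)


def derivative(func, x):
    """Return func(x) and func'(x) by seeding the input derivative with 1."""
    out = func(Dual(x, 1.0))
    return out.val, out.der


# f(x) = x sin(x) + 3, so f'(x) = sin(x) + x cos(x)
value, slope = derivative(lambda x: x * dual_sin(x) + 3, 2.0)
print(value, slope)
```

Libraries such as JAX follow the same principle but trace whole programs and also offer reverse mode.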

Training a neural network through backpropagation is a typical application of autodiff (TensorFlow, PyTorch, ...): backpropagation is reverse-mode automatic differentiation.
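As a small illustration, `jax.grad` (which differentiates in reverse mode) can compute the gradient of a mean-squared-error loss of a linear model with respect to its weights, the same kind of computation backpropagation performs. The model, data and the function name `loss` are made up for this sketch; the result is compared with the analytic gradient $\frac{2}{N} X^T (X w - y)$.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Mean squared error of a linear model with predictions x @ w."""
    return jnp.mean((x @ w - y) ** 2)

# Illustrative toy data: 3 samples, 2 features.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([0.1, -0.2])

g_ad = jax.grad(loss)(w, x, y)                 # gradient w.r.t. the first argument, w
g_exact = 2.0 / len(y) * x.T @ (x @ w - y)    # analytic gradient for comparison
print(g_ad, g_exact)
```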

Links:
* https://en.wikipedia.org/wiki/Automatic_differentiation
* https://www.cs.toronto.edu/~rgrosse/courses/csc321_2018/slides/lec10.pdf

## Classic Gaussian error propagation

```python
from sympy import *
from IPython.display import display, Latex
```

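```python
def gaussian_error_propagation(f, variables):
    """
    Gaussian error propagation for uncorrelated input variables.
    f: formula (sympy expression)
    variables: list of independent variables and corresponding uncertainties
        [(x1, sigma_x1), (x2, sigma_x2), ...]
    Returns sqrt(sum_i (df/dx_i)**2 * sigma_i**2) as a sympy expression.
    """
    variance = S(0)  # sympy zero; accumulates the variance of f
    for (x, sigma) in variables:
        variance += diff(f, x)**2 * sigma**2
    return sqrt(simplify(variance))
```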
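For reference, the function above implements the standard first-order propagation formula for a function $f(x_1, x_2, \ldots)$ of uncorrelated variables with uncertainties $\sigma_{x_i}$; correlations between the inputs are assumed to be absent:

$$\sigma_f = \sqrt{\sum_i \left(\frac{\partial f}{\partial x_i}\right)^2 \sigma_{x_i}^2}$$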

As a simple example, consider the volume of a cylinder with radius $r$ and height $h$:

```python
r, h, sigma_r, sigma_h = symbols('r, h, sigma_r, sigma_h', positive=True)
V = pi * r**2 * h  # volume of a cylinder
```

```python
sigma_V = gaussian_error_propagation(V, [(r, sigma_r), (h, sigma_h)])
# raw f-string so that LaTeX commands such as \, and \sigma are not read as escapes
display(Latex(rf"$V = {latex(V)}, \, \sigma_V = {latex(sigma_V)}$"))
```
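Worked by hand, the partial derivatives are $\partial V/\partial r = 2\pi r h$ and $\partial V/\partial h = \pi r^2$, so (using $r > 0$, as declared with `positive=True`) the symbolic result should be equivalent to the following, up to how SymPy arranges the terms:

$$\sigma_V = \sqrt{(2\pi r h)^2 \sigma_r^2 + (\pi r^2)^2 \sigma_h^2} = \pi r \sqrt{4 h^2 \sigma_r^2 + r^2 \sigma_h^2}$$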


Plug in some numbers and print the calculated volume with its uncertainty:

```python
r_meas = 3  # cm
sigma_r_meas = 0.1  # cm
h_meas = 5  # cm
sigma_h_meas = 0.1  # cm
```

```python
central_value = V.subs([(r, r_meas), (h, h_meas)]).evalf()
sigma = sigma_V.subs([(r, r_meas), (sigma_r, sigma_r_meas),
                      (h, h_meas), (sigma_h, sigma_h_meas)]).evalf()
# raw f-string keeps \pm and \, intact; float() gives standard float formatting
display(Latex(rf"$$V = ({float(central_value):.1f} \pm {float(sigma):.1f}) \, \mathrm{{cm}}^3$$"))
```
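As a sanity check, the closed-form expression for the cylinder can also be evaluated with plain Python floats. The variable names here are illustrative and deliberately differ from the SymPy symbols so that they do not overwrite them.

```python
import math

r_cm, h_cm = 3.0, 5.0        # radius and height in cm
sr_cm, sh_cm = 0.1, 0.1      # their uncertainties in cm

V_cm3 = math.pi * r_cm**2 * h_cm
sigma_V_cm3 = math.pi * r_cm * math.sqrt(4 * h_cm**2 * sr_cm**2 + r_cm**2 * sh_cm**2)
print(f"V = ({V_cm3:.1f} +/- {sigma_V_cm3:.1f}) cm**3")  # prints V = (141.4 +/- 9.8) cm**3
```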


## Error propagation with autodiff

Links:
* https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html
* https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
* http://theoryandpractice.org/intro-exp-phys-book/error-propagation/error_propagation_with_jax.html
* https://www.youtube.com/watch?v=wG_nF1awSSY

```python
from jax import jacfwd
import jax.numpy as jnp
```

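```python
def error_prop_jax_gen(f, x, dx):
    """
    Gaussian error propagation for a scalar function f and uncorrelated inputs.
    f: function of a 1D array x, returning a scalar
    x: central values of the input variables
    dx: corresponding uncertainties
    """
    jac = jacfwd(f)  # function returning the gradient of f (forward-mode AD)
    return jnp.sqrt(jnp.sum((jac(x) * dx)**2))
```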
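`error_prop_jax_gen` assumes uncorrelated inputs and a scalar output. A possible generalization, sketched below under those assumptions being dropped, propagates a full input covariance matrix $C$ to the output covariance $C_f = J C J^T$, where $J$ is the Jacobian at $x$. The helper `error_prop_cov` and the correlation coefficient $\rho = 0.5$ are hypothetical choices for illustration only.

```python
import jax.numpy as jnp
from jax import jacfwd

def error_prop_cov(f, x, cov):
    """Hypothetical helper: first-order propagation C_f = J C J^T.

    Works for scalar- and vector-valued f and returns an (m, m) matrix.
    """
    J = jnp.atleast_2d(jacfwd(f)(x))  # shape (m, n); m = 1 for scalar f
    return J @ cov @ J.T

def volume(x):
    # volume of a cylinder with (x[0] = radius, x[1] = height)
    return jnp.pi * x[1] * x[0]**2

x = jnp.array([3.0, 5.0])
dx = jnp.array([0.1, 0.1])

# Uncorrelated inputs: a diagonal covariance reproduces error_prop_jax_gen.
cov_uncorr = jnp.diag(dx**2)
sigma_uncorr = jnp.sqrt(error_prop_cov(volume, x, cov_uncorr)[0, 0])

# Assumed (illustrative) correlation between radius and height.
rho = 0.5
cov_corr = jnp.array([[dx[0]**2, rho * dx[0] * dx[1]],
                      [rho * dx[0] * dx[1], dx[1]**2]])
sigma_corr = jnp.sqrt(error_prop_cov(volume, x, cov_corr)[0, 0])
print(f"{sigma_uncorr:.2f} {sigma_corr:.2f}")
```

A positive correlation between $r$ and $h$ increases the uncertainty here, because both partial derivatives of $V$ are positive.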

```python
# volume of a cylinder with (x[0] = radius, x[1] = height)
def f(x):
    return jnp.pi * x[1] * x[0]**2
```

```python
x = jnp.array([3., 5.])
dx = jnp.array([0.1, 0.1])
print(f"V = {f(x):0.1f} +/- {error_prop_jax_gen(f, x, dx):0.1f} cm**3")
```

V = 141.4 +/- 9.8 cm**3
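The result agrees with the SymPy calculation. A quick check, sketched below, confirms that the Jacobian computed by `jacfwd` matches the hand-derived partial derivatives $\partial V/\partial r = 2\pi r h$ and $\partial V/\partial h = \pi r^2$ to within floating-point precision (JAX uses 32-bit floats by default). The names `volume`, `J_ad` and `J_exact` are illustrative.

```python
import jax.numpy as jnp
from jax import jacfwd

def volume(x):
    # volume of a cylinder with (x[0] = radius, x[1] = height)
    return jnp.pi * x[1] * x[0]**2

x = jnp.array([3.0, 5.0])
J_ad = jacfwd(volume)(x)                                       # derivatives via AD
J_exact = jnp.array([2 * jnp.pi * x[0] * x[1], jnp.pi * x[0]**2])  # [2 pi r h, pi r^2]
print(J_ad, J_exact)
```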
