# A Modern Proof of the Adjoint Method for Neural ODE
The authors present a proof that is shorter and easier to follow than the original one by Pontryagin et al. (1962).
Let $\textbf{z}(t)$ follow the differential equation
\begin{equation}
\frac{d\textbf{z}(t)}{dt}=f(\textbf{z}(t), t,\theta),
\end{equation} where $\theta$ are the parameters. We will prove that if we define an adjoint state
\begin{equation}
\textbf{a}(t)=\frac{dL}{d\textbf{z}(t)}
\end{equation} then it follows the differential equation
\begin{equation}
\frac{d\textbf{a}(t)}{dt}=-\textbf{a}(t)\frac{\partial f(\textbf{z}(t), t, \theta)}{\partial \textbf{z}(t)}
\end{equation}
The adjoint state is the gradient of the loss w.r.t. the hidden state at a specified time $t$. In standard neural networks, such gradients are computed by the chain rule:
\begin{equation}
\frac{dL}{d\textbf{h}_{t}}=\frac{dL}{d\textbf{h}_{t+1}}\frac{d\textbf{h}_{t+1}}{d\textbf{h}_{t}}
\end{equation}
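For concreteness, the discrete chain rule above can be checked numerically on a single toy layer $\textbf{h}_{t+1}=\tanh(W\textbf{h}_t)$ with $L=\frac{1}{2}\|\textbf{h}_{t+1}\|^2$ (an illustrative setup, not from the paper), using the row-vector convention for gradients:

```python
import numpy as np

# Toy layer h_{t+1} = tanh(W h_t) with loss L = 0.5 * ||h_{t+1}||^2.
rng = np.random.default_rng(0)
W = rng.standard_normal((3, 3))
h_t = rng.standard_normal(3)

h_next = np.tanh(W @ h_t)
dL_dh_next = h_next                      # gradient of 0.5 * ||h_{t+1}||^2
jac = (1 - h_next**2)[:, None] * W       # dh_{t+1}/dh_t for tanh(W h_t)
dL_dh_t = dL_dh_next @ jac               # chain rule, row-vector convention

# Central finite-difference check of dL/dh_t.
eps, fd = 1e-6, np.zeros(3)
for i in range(3):
    hp, hm = h_t.copy(), h_t.copy()
    hp[i] += eps
    hm[i] -= eps
    fd[i] = (0.5 * np.sum(np.tanh(W @ hp)**2)
             - 0.5 * np.sum(np.tanh(W @ hm)**2)) / (2 * eps)

assert np.allclose(dL_dh_t, fd, atol=1e-5)
```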
With a continuous hidden state, we can write the transformation after an $\epsilon$ change in time as
\begin{equation}
\textbf{z}(t+\epsilon)=\textbf{z}(t)+\int_{t}^{t+\epsilon}f(\textbf{z}(\tau), \tau,\theta)d\tau=T_{\epsilon}(\textbf{z}(t), t)
\end{equation} and the chain rule can also be applied
\begin{equation}
\frac{dL}{d\textbf{z}(t)} = \frac{dL}{d\textbf{z}(t+\epsilon)}\frac{d\textbf{z}(t+\epsilon)}{d\textbf{z}(t)}
\end{equation} or
\begin{equation}
\textbf{a}(t) = \textbf{a}(t+\epsilon)\frac{dT_{\epsilon}(\textbf{z}(t), t)}{d\textbf{z}(t)}
\end{equation}
The proof follows from the definition of the derivative:
\begin{align}
\frac{d\textbf{a}(t)}{dt} &= \lim_{\epsilon\rightarrow 0^{+}}\frac{\textbf{a}(t+\epsilon)-\textbf{a}(t)}{\epsilon}\\
&= \lim_{\epsilon\rightarrow 0^{+}}\frac{\textbf{a}(t+\epsilon)\left(I-\frac{\partial T_{\epsilon}(\textbf{z}(t), t)}{\partial \textbf{z}(t)}\right)}{\epsilon}\\
&=\lim_{\epsilon\rightarrow 0^{+}}\frac{\textbf{a}(t+\epsilon)\left(I-\frac{\partial\left(\textbf{z}(t)+\epsilon f(\textbf{z}(t), t, \theta)+O(\epsilon^2)\right)}{\partial \textbf{z}(t)}\right)}{\epsilon}\\
&=\lim_{\epsilon\rightarrow 0^{+}}\left(-\textbf{a}(t+\epsilon)\frac{\partial f(\textbf{z}(t), t, \theta)}{\partial \textbf{z}(t)} + O(\epsilon)\right)\\
&= -\textbf{a}(t)\frac{\partial f(\textbf{z}(t), t, \theta)}{\partial \textbf{z}(t)}
\end{align}
As in backpropagation, the ODE for the adjoint state must be solved *backwards* in time. We specify the boundary condition at the last time point, which is simply the gradient of the loss w.r.t. the state there, and can then obtain the gradient w.r.t. the hidden state at any earlier time, including the initial value.
\begin{align}
\textbf{a}(t_N)&=\frac{dL}{d\textbf{z}(t_N)}\\
\textbf{a}(t_0)&=\textbf{a}(t_N)+\int_{t_0}^{t_N}\textbf{a}(t)\frac{\partial f(\textbf{z}(t),t,\theta)}{\partial \textbf{z}(t)}dt
\end{align}
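The backward adjoint solve can be sanity-checked numerically. The sketch below assumes toy linear dynamics $f(\textbf{z},t)=A\textbf{z}$ and the illustrative loss $L=\frac{1}{2}\|\textbf{z}(t_N)\|^2$ (neither is from the paper); the adjoint integrated backwards from $t_N$ to $t_0$ should match finite differences of $L$ w.r.t. $\textbf{z}(t_0)$:

```python
import numpy as np

def rk4(func, y0, t0, t1, steps=1000):
    """Fixed-step RK4 from t0 to t1 (works backwards when t1 < t0)."""
    y, h, t = np.array(y0, dtype=float), (t1 - t0) / steps, t0
    for _ in range(steps):
        k1 = func(y, t)
        k2 = func(y + 0.5 * h * k1, t + 0.5 * h)
        k3 = func(y + 0.5 * h * k2, t + 0.5 * h)
        k4 = func(y + h * k3, t + h)
        y, t = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4), t + h
    return y

# Toy linear dynamics dz/dt = A z, loss L = 0.5 * ||z(tN)||^2.
A = np.array([[-0.5, 1.0], [-1.0, -0.5]])
z0, t0, tN = np.array([1.0, 0.5]), 0.0, 1.0

zN = rk4(lambda z, t: A @ z, z0, t0, tN)
aN = zN.copy()                          # a(tN) = dL/dz(tN) = z(tN)

# Backward adjoint solve: da/dt = -a (df/dz) = -a @ A for linear f.
a0 = rk4(lambda a, t: -(a @ A), aN, tN, t0)

# Central finite-difference check of dL/dz(t0).
eps, fd = 1e-6, np.zeros_like(z0)
for i in range(len(z0)):
    zp, zm = z0.copy(), z0.copy()
    zp[i] += eps
    zm[i] -= eps
    Lp = 0.5 * np.sum(rk4(lambda z, t: A @ z, zp, t0, tN) ** 2)
    Lm = 0.5 * np.sum(rk4(lambda z, t: A @ z, zm, t0, tN) ** 2)
    fd[i] = (Lp - Lm) / (2 * eps)

assert np.allclose(a0, fd, atol=1e-5)
```

The same backward integration works for any $f$; only the vector-Jacobian product $\textbf{a}\,\partial f/\partial\textbf{z}$ changes.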
If $L$ also depends on intermediate time points $t_{1},t_2,\cdots,t_{N-1}$, we can repeat the adjoint step for each of the intervals $[t_{N-1},t_{N}], [t_{N-2},t_{N-1}],\cdots,[t_{0},t_{1}]$ in backward order, adding the direct gradient $\frac{\partial L}{\partial \textbf{z}(t_i)}$ to the adjoint state at each observation time, and sum up the obtained gradients.
We can view $\theta$ and $t$ as states with constant time derivatives and write
\begin{align}
&\frac{\partial \theta(t)}{\partial t}=0 & \frac{\partial t(t)}{\partial t}=1
\end{align}
We can then combine these with $\textbf{z}$ to form an augmented state with corresponding differential equation and adjoint state,
\begin{align}
\frac{d}{dt}\begin{bmatrix}\textbf{z}\\\theta\\t\end{bmatrix}(t)&=f_{\text{aug}}([\textbf{z},\theta,t]):=\begin{bmatrix}f([\textbf{z},\theta, t])\\\textbf{0}\\1\end{bmatrix}\\
\textbf{a}_{\text{aug}} &:= \begin{bmatrix}\textbf{a}\\\textbf{a}_{\theta}\\\textbf{a}_{t}\end{bmatrix}, \textbf{a}_{\theta}(t):= \frac{dL}{d\theta(t)}, \textbf{a}_{t}:=\frac{dL}{dt(t)}
\end{align}
The Jacobian of $f_{\text{aug}}$ has the form
\begin{equation}
\frac{\partial f_{\text{aug}}}{\partial[\textbf{z},\theta, t]} =
\begin{bmatrix}
\frac{\partial f}{\partial \textbf{z}} &\frac{\partial f}{\partial \theta} &\frac{\partial f}{\partial t}\\
\textbf{0} &\textbf{0} &\textbf{0}\\
\textbf{0} &\textbf{0} &\textbf{0}
\end{bmatrix}
\end{equation}
Recall
\begin{equation}
\frac{d\textbf{a}_{\text{aug}}(t)}{dt}=-[\textbf{a}(t), \textbf{a}_{\theta}(t),\textbf{a}_{t}(t)]\frac{\partial f_{\text{aug}}}{\partial[\textbf{z}, \theta, t]}(t)=-[\textbf{a}\frac{\partial f}{\partial \textbf{z}}, \textbf{a}\frac{\partial f}{\partial \theta}, \textbf{a}\frac{\partial f}{\partial t}](t)
\end{equation}
The first element gives the adjoint differential equation, as expected. The second element can be used to obtain the total gradient with respect to the parameters by integrating over the full interval, with the boundary condition $\textbf{a}_{\theta}(t_N)=\textbf{0}$ since the loss does not depend on $\theta$ directly at $t_N$:
\begin{equation}
\frac{dL}{d\theta}=\textbf{a}_{\theta}(t_0)=-\int_{t_N}^{t_0}\textbf{a}(t)\frac{\partial f(\textbf{z}(t), t, \theta)}{\partial \theta}dt
\end{equation}
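A minimal numeric sketch of the augmented backward pass, assuming toy scalar dynamics $\frac{dz}{dt}=\theta z$ and the illustrative loss $L=\frac{1}{2}z(t_N)^2$ (both chosen here for their closed-form solutions, not taken from the paper):

```python
import numpy as np

def rk4(func, y0, t0, t1, steps=2000):
    """Fixed-step RK4 from t0 to t1 (works backwards when t1 < t0)."""
    y, h, t = np.array(y0, dtype=float), (t1 - t0) / steps, t0
    for _ in range(steps):
        k1 = func(y, t)
        k2 = func(y + 0.5 * h * k1, t + 0.5 * h)
        k3 = func(y + 0.5 * h * k2, t + 0.5 * h)
        k4 = func(y + h * k3, t + h)
        y, t = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4), t + h
    return y

# Toy scalar dynamics dz/dt = theta * z, loss L = 0.5 * z(tN)^2.
theta, z0, t0, tN = 0.3, 1.2, 0.0, 1.0
zN = z0 * np.exp(theta * (tN - t0))      # closed-form forward solution

# Augmented backward state [z, a, a_theta]:
#   dz/dt       =  theta * z
#   da/dt       = -a * (df/dz)     = -a * theta
#   da_theta/dt = -a * (df/dtheta) = -a * z
def aug(y, t):
    z, a, _ = y
    return np.array([theta * z, -a * theta, -a * z])

# Boundary values at tN: a(tN) = dL/dz(tN) = z(tN), a_theta(tN) = 0.
z_back, a0, dL_dtheta = rk4(aug, [zN, zN, 0.0], tN, t0)

# Analytic answers: dL/dtheta = z(tN)^2 * (tN - t0),
# and dL/dz(t0) = z(tN) * exp(theta * (tN - t0)).
assert np.isclose(dL_dtheta, zN**2 * (tN - t0), atol=1e-8)
assert np.isclose(a0, zN * np.exp(theta * (tN - t0)), atol=1e-8)
```

Integrating the augmented state as one system is exactly what makes a single backward ODE solve return $\textbf{a}(t_0)$ and $dL/d\theta$ together.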
Finally, we also get gradients w.r.t. $t_0$ and $t_N$, the start and end of the integration interval. The gradient w.r.t. $t_N$ follows directly from the chain rule and serves as the boundary condition for $\textbf{a}_{t}$:
\begin{align}
\frac{dL}{dt_{N}} &=\frac{dL}{d\textbf{z}(t_N)}\frac{d\textbf{z}(t_N)}{dt_N}=\textbf{a}(t_N)f(\textbf{z}(t_N), t_N, \theta)\\
\frac{dL}{dt_{0}} &=\textbf{a}_{t}(t_0)=\frac{dL}{dt_{N}}-\int_{t_N}^{t_0}\textbf{a}(t)\frac{\partial f(\textbf{z}(t), t, \theta)}{\partial t}dt
\end{align}
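By the chain rule, $\frac{dL}{dt_N}=\frac{dL}{d\textbf{z}(t_N)}\frac{d\textbf{z}(t_N)}{dt_N}=\textbf{a}(t_N)f(\textbf{z}(t_N), t_N, \theta)$. A quick sanity check on a toy ODE with explicit time dependence (all values illustrative, not from the paper):

```python
import numpy as np

# Toy dynamics dz/dt = t has the closed-form solution
# z(tN) = z0 + 0.5 * (tN^2 - t0^2); take loss L = 0.5 * z(tN)^2.
z0, t0, tN = 1.0, 0.0, 2.0

def loss(t_end):
    z_end = z0 + 0.5 * (t_end**2 - t0**2)
    return 0.5 * z_end**2

zN = z0 + 0.5 * (tN**2 - t0**2)
aN = zN              # a(tN) = dL/dz(tN) for this loss
f_at_tN = tN         # f(z, t) = t evaluated at t = tN

# Chain-rule prediction dL/dtN = a(tN) * f(z(tN), tN),
# checked by central finite differences on the end time.
eps = 1e-6
fd = (loss(tN + eps) - loss(tN - eps)) / (2 * eps)
assert np.isclose(aN * f_at_tN, fd, atol=1e-5)
```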