$$
\def\E{{\mathbb{E}}}
\def\P{{\mathbb{P}}}
\def\R{{\mathbb{R}}}
\def\L{{\mathcal{L}}}
\def\M{{\mathcal{M}}}
\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}
\def\div{\operatorname{div}}
\def\rmd{{\mathrm{d}}}
\def\rot{{\mathrm{rot}}}
\def\trans{{\mathrm{trans}}}
\def\SO{{\mathrm{SO}(3)}}
\def\SE{{\mathrm{SE}(3)}}
$$
# Flow matching
## Ordinary differential equation
- $\frac{\partial}{\partial t} \phi_t(x) = u_t(\phi_t(x))$
- $p_t = [\phi_t]_* p_0$
- Assuming the vector field $u_t$ is $C^1$ and bounded, the flow $\phi_t$ is a $C^1$ diffeomorphism (Mathieu et al. 2019).
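As a quick sanity check, here is a minimal sketch (plain `numpy`; the field `u` and the target `x1` below are hypothetical) of pushing samples through the flow ODE with explicit Euler steps:
```python
import numpy as np

def integrate_flow(u, x0, n_steps=100):
    """Explicit Euler integration of d/dt phi_t(x) = u_t(phi_t(x)) from t=0 to t=1."""
    x = np.asarray(x0, dtype=float)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = i * dt
        x = x + dt * u(t, x)          # x_{t+dt} ~= x_t + dt * u_t(x_t)
    return x

# Example: the linear-schedule conditional field u_t(x|x1) = (x1 - x)/(1 - t)
# transports x0 towards a fixed (hypothetical) target x1.
x1 = np.array([1.0, -2.0])
u_cond = lambda t, x: (x1 - x) / max(1.0 - t, 1e-3)
print(integrate_flow(u_cond, np.zeros(2)))   # ~= x1
```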
## Marginal vector field induced by the conditional vector field
The following vector field $u_t$ is defined by _marginalising_ over the _conditional_ vector field $u_t(\cdot|x_1)$, which generates the conditional probability path $p_t(\cdot|x_1)$:
$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1)q(x_1)}{p_t(x)} \rmd x_1$
The _marginal_ vector field $u_t$ generates the marginal probability path $p_t$ since it satisfies the continuity equation:
$$
\frac{\partial p_t(x)}{\partial t}
= \int \frac{\partial}{\partial t} p_t(x|x_1) q(x_1) dx_1 \\
\quad = -\int \div (p_t(x|x_1) u_t(x|x_1)) q(x_1) dx_1 \\
\quad = -\int \div (p_t(x|x_1) u_t(x|x_1) q(x_1)) dx_1 \\
\quad = -\div (u_t(x) p_t(x))
$$
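A minimal training sketch (PyTorch; the MLP `v_theta` and the toy distributions are placeholders) of the conditional flow matching objective: regressing a learned field onto the conditional target, here $u_t(x_t|x_1) = x_1 - x_0$ for the linear path, whose minimiser is the marginal field $u_t$ above.
```python
import torch
import torch.nn as nn

v_theta = nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 2))

def cfm_loss(x0, x1):
    """x0 ~ p_0, x1 ~ q; linear path x_t = (1-t) x0 + t x1, target u_t(x_t|x1) = x1 - x0."""
    t = torch.rand(x0.shape[0], 1)
    xt = (1 - t) * x0 + t * x1
    target = x1 - x0
    pred = v_theta(torch.cat([xt, t], dim=-1))
    return ((pred - target) ** 2).mean()

loss = cfm_loss(torch.randn(128, 2), torch.randn(128, 2) + 3.0)  # toy p_0 and q
loss.backward()
```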
## Conditional flow matching
Here, we aim to generate the probability path $p_t(\cdot|y)$ conditioned on some observation $y$, given a likelihood $p(y|x_1)$.
### Conditional vector field
We want to find a vector field $u_t(\cdot|y)$ conditioned on this observation $y$ such that it satisfies the continuity equation:
$$
\frac{\partial p_t(x|y)}{\partial t}
= -\div \left(u_t(x|y) p_t(x|y)\right)
$$
We postulate that $u_t(\cdot|y) = u_t + \nabla_x \log p(y|x)$ (or perhaps with some additional rescaling as a function of $t$).
Thus
$$
\frac{\partial p_t(x|y)}{\partial t}
= -\div \left((u_t(x) + \nabla_x \log p(y|x)) p_t(x|y)\right) \\
= - p_t(x|y) \div(u_t) - \langle u_t, \nabla_x p_t(x|y) \rangle - p_t(x|y) \div(\nabla_x \log p(y|x)) - \langle \nabla_x \log p(y|x), \nabla_x p_t(x|y) \rangle \\
= - p_t(x|y) \div(u_t) - p_t(x|y)\langle u_t, \nabla_x \log p_t(x|y) \rangle - p_t(x|y) \left(\Delta_x \log p(y|x) + \| \nabla_x \log p(y|x)\|^2 + \langle \nabla \log p(y|x), \nabla_x \log p_t(x) \rangle\right)
$$
Now, starting with Bayes' rule, $p_t(x|y) = p_t(y|x)p_t(x)/p(y)$, we have
$$
\frac{\partial p_t(x|y)}{\partial t}
= \frac{p_t(y|x)}{p(y)} \frac{\partial p_t(x)}{\partial t} + \frac{p_t(x)}{p(y)} \frac{\partial p_t(y|x)}{\partial t} \\
= - \frac{p_t(x|y)}{p_t(x)} \div(u_t p_t)(x) + \frac{p_t(x|y)}{p_t(y|x)} \frac{\partial p_t(y|x)}{\partial t} \\
= - p_t(x|y) \div(u_t) - p_t(x|y) \langle u_t, \nabla \log p_t(x) \rangle + p_t(x|y) \frac{\partial \log p_t(y|x)}{\partial t} \\
= - p_t(x|y) \div(u_t) - p_t(x|y) \langle u_t, \nabla \log p_t(x|y) \rangle + p_t(x|y) \langle u_t, \nabla \log p_t(y|x) \rangle + p_t(x|y) \frac{\partial \log p_t(y|x)}{\partial t} \\
= - \div(p_t(\cdot|y) u_t)(x) + \frac{p_t(y|x)p_t(x)}{p(y)} \langle u_t, \nabla \log p_t(y|x) \rangle + p_t(x|y) \frac{\partial \log p_t(y|x)}{\partial t} \\
= \ ... \\
= -\div\left( (u_t + v_t(\cdot\ ;y)) p_t(\cdot|y) \right)(x)
$$
<!-- = - p_t(x|y) \div(u_t) - p_t(x|y) \langle u_t, \nabla \log p_t(x|y) \rangle + p_t(x|y) \langle u_t, \nabla \log p_t(y|x) \rangle \\ -->
<!-- We have that
$$
\frac{\partial p_t(y, x)}{\partial t}
= \frac{\partial }{\partial t} \int p(y|x_1) p(x_t, x_1) \rmd x_1
= \int p(y|x_1) p(x_1) \frac{\partial }{\partial t} p(x_t|x_1) \rmd x_1
= -\int p(y|x_1) p(x_1) \div (u_t(x_t) p(x_t|x_1)) \rmd x_1 \\
= - \div (\int p(y|x_1) p(x_1) p(x_t|x_1) u_t(x_t) \rmd x_1)
= - \div (p(y,x_t) u_t(x_t))
$$ -->
We have that
$$
\frac{\partial p_t(y|x)}{\partial t}
= \frac{\partial }{\partial t} \int p(y|x_1) p(x_1|x_t) \rmd x_1
= \int p(y|x_1) \frac{\partial }{\partial t} p(x_t|x_1) \rmd x_1
= -\int p(y|x_1) \div (u_t(x_t) p(x_t|x_1)) \rmd x_1 \\
= - \div (\int p(y|x_1) p(x_t|x_1) u_t(x_t) \rmd x_1)
= - \div (p(y|x_t) u_t(x_t))
$$
Thus
$$
\frac{\partial \log p_t(y|x)}{\partial t}
= \frac{1}{p_t(y|x)} \frac{\partial p_t(y|x)}{\partial t}
= - \frac{1}{p_t(y|x)} \div (p(y|x_t) u_t(x_t))
$$
### Vector field parameterisation
It can be convenient, e.g. in protein modelling, to directly parameterise a neural network $x_1^\theta$ such that $x_1^\theta(t, x_t) \approx \E[x_1|x_t]$.
Then a parametric vector field can be constructed as suggested in Bose et al. 2023:
$v_\theta(t, x_t) \triangleq - \frac{1}{2} \frac{1}{1 - t} \nabla_{x_t} d(x_1^\theta(t, x_t), x_t)^2 = \frac{1}{1 - t} \log_{x_t}(x_1^\theta(t, x_t))$.
Similarly, if the vector field is directly parameterised with a neural network $v_\theta$ then one can get an estimate of the conditional mean:
$\E[x_1|x_t] \approx x_1^\theta(t, x_t) \triangleq \exp_{x_t}((1 - t) v_\theta(t, x_t))$
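A minimal sketch of this conversion in the Euclidean special case (where $\log_x(y) = y - x$ and $\exp_x(v) = x + v$); the function names are hypothetical:
```python
import numpy as np

def vf_from_denoiser(t, x_t, x1_pred):
    """v_theta(t, x_t) = log_{x_t}(x1_theta) / (1 - t), Euclidean log map."""
    return (x1_pred - x_t) / (1.0 - t)

def denoiser_from_vf(t, x_t, v):
    """x1_theta(t, x_t) = exp_{x_t}((1 - t) v_theta), Euclidean exp map."""
    return x_t + (1.0 - t) * v

x_t, x1_pred, t = np.zeros(3), np.ones(3), 0.25
v = vf_from_denoiser(t, x_t, x1_pred)
assert np.allclose(denoiser_from_vf(t, x_t, v), x1_pred)
```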
### Guidance term
One would like to 'guide' the generative model with a likelihood term:
- $p_t(y|x) = \int p(y|x_1) p_t(x_1|x) dx_1$ where $p(y|x_1) = \mathcal{N}(y|A x_1, \sigma_y^2)$
- Assuming the flow $\phi_t$ is bijective, $x_1$ is a deterministic function of $x_t$ (the ODE endpoint), so $p_t(x_1|x_t)$ is a Dirac mass: $p_t(x_1|x_t) \approx \delta_{x_1^\theta(t, x_t)}(x_1)$
- Thus $p_t(y|x_t) \approx \mathcal{N}\left(y|A(x_1^\theta(t, x_t)), \sigma_y^2\right)$
$\Emile$: unsure about what follows
- and $p_t(x_1|x_t) = p_t(x_t|x_1) p(x_1) / p_t(x_t)$
- and $p_t(x_t|x_1) = \mathcal{N}(s_t x_1, \sigma_t^2)$ with $s_t = t$ and $\sigma_t = 1 - t$
- Following reconstruction guidance work (e.g. Rozet and Louppe 2023, Finzi et al. 2023), a reasonable approximation is: $p_t(x_1|x_t) \approx \mathcal{N}(x_1|x_1^\theta(t, x_t), \gamma \sigma_t^2 / s_t^2) = \mathcal{N}(x_1|x_1^\theta(t, x_t), \gamma (1 - t)^2 / t^2)$
- Thus $p_t(y|x_t) \approx \mathcal{N}\left(y|A(x_1^\theta(t, x_t)), \sigma_y^2 I + \gamma \frac{(1 - t)^2}{t^2} A A^\top\right)$
- Thus $\nabla_{x_t} \log p_t(y|x_t) \approx \left(\frac{\partial x_1^\theta(t, x_t)}{\partial x_t}\right)^\top A^\top \left(\sigma_y^2 I + \gamma \frac{(1 - t)^2}{t^2} A A^\top \right)^{-1} \left(y - A(x_1^\theta(t, x_t))\right)$
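A minimal sketch (PyTorch) of this guidance gradient; the Jacobian term is handled by autograd as a vector-Jacobian product through the denoiser. The denoiser `x1_theta` below is a placeholder and the shapes are hypothetical.
```python
import torch

def guidance_grad(x1_theta, x_t, t, y, A, sigma_y=0.1, gamma=1.0):
    """grad_{x_t} log N(y | A x1_theta(t, x_t), sigma_y^2 I + gamma (1-t)^2/t^2 A A^T)."""
    x_t = x_t.detach().requires_grad_(True)
    x1 = x1_theta(t, x_t)                               # approx E[x1 | x_t]
    cov = sigma_y**2 * torch.eye(A.shape[0]) + gamma * ((1 - t)**2 / t**2) * A @ A.T
    resid = y - A @ x1
    log_lik = -0.5 * resid @ torch.linalg.solve(cov, resid)   # up to an x_t-free constant
    return torch.autograd.grad(log_lik, x_t)[0]

# Usage with a placeholder "denoiser", for illustration only.
d, m = 8, 3
A = torch.randn(m, d)
x1_theta = lambda t, x: x / max(t, 1e-3)
g = guidance_grad(x1_theta, torch.randn(d), 0.5, torch.randn(m), A)
```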
### Doob's h-transform
Let's denote by $\mathbb{P}$ the distribution of the following process
$$\rmd X_t = f_t(X_t) \rmd t + \sigma_t \rmd B_t.$$
Then, using results from Doob's $h$-transform theory (Rogers and Williams, 2000), the process $H_t$ which additionally satisfies $X_T = x_T$ (i.e. the law of $X$ conditioned on $X_T = x_T$) is given by
$$\rmd H_t = \left( f_t(H_t) + \sigma_t^2 \nabla \log h_t(H_t) \right)\rmd t + \sigma_t \rmd B_t$$
with $h_t(x) \triangleq \P(X_T = x_T|X_t = x)$ the *h-transform*.
$\Emile$: With $\sigma_t > 0$, $\nabla \log \P_t(X_T = x_T|H_t)$ is well defined, but as per the DSBM paper (after definition 1) it seems that this generalises to $\sigma_t = \sigma \rightarrow 0$?
### Going through the diffusion ODE - AndrewC
Assume $x_0$ is data and $x_1$ is noise. Assume $x_t = (1-t) x_0 + t x_1$. The equivalent SDE that gives the same marginals is
$$ dx = \frac{-1}{1-t} x dt + \sqrt{\frac{2t}{1-t}} dw$$
See https://www.overleaf.com/project/6500281c8c4a9d7736a0d556
Now, using the fact that $\nabla_{x_t} \log p_t(x_t) = \E \left[ -x_1 / t \mid x_t \right]$ and that the vector field from flow matching is $\nu_t(x) = \E \left[x_0 - x_1 \mid x_t \right]$, we get
$$ \nabla_{x_t} \log p_t(x_t) = \frac{(1-t) \nu_t(x_t) - x_t}{t} $$
Now write down the flow matching ODE from the diffusion perspective
$$ dx = \left(f(x,t) - \frac{1}{2} g(t)^2 \nabla_x \log p_t(x)\right)dt$$
$$ dx = \left( \frac{-1}{1-t}x - \frac{1}{2} \frac{2t}{1-t} \nabla_x \log p_t(x) \right)dt$$
Note that this matches the normal flow matching ODE integration of $\nu_t(x)$ as we can see by substituting in the relation between $\nabla_{x_t}\log p_t(x_t)$ and $\nu_t(x_t)$
$$ dx = \left( \frac{-1}{1-t} x - \frac{t}{1-t} \frac{(1-t) \nu_t(x_t) - x_t}{t} \right) dt$$
which simplifies to
$$ dx = \left( - \nu_t(x_t) \right)dt $$
Now that we know how the flow matching ODE corresponds to an ODE on the score, we can simply adjust $\nabla_x \log p_t(x)$ with our guidance term, using $\nabla_x \log p_t(x | y) = \nabla_x \log p_t(y | x) + \nabla_x \log p_t(x)$:
$$ dx = \left( \frac{-1}{1-t}x - \frac{1}{2} \frac{2t}{1-t} \left\{ \nabla_x \log p_t(y|x) + \nabla_x \log p_t(x) \right\} \right)dt$$
From this we can find the conditional vector field $\tilde{\nu}_t(x | y)$
$$ - \tilde{\nu}_t(x_t | y) = \frac{-1}{1-t} x_t - \frac{t}{1-t} \left\{ \nabla_{x_t} \log p_t(y | x_t) + \frac{(1-t) \nu_t(x_t) - x_t}{t} \right\}$$
$$ \tilde{\nu}_t(x_t | y) = \frac{t}{1-t} \nabla_{x_t}\log p_t(y | x_t) + \nu_t(x_t)$$
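A minimal sampling sketch for this guided field (plain `numpy`, with the convention of this subsection: data at $t=0$, noise at $t=1$, so we integrate $dx = -\tilde{\nu}_t(x|y)\,dt$ from $t=1$ down to $t=0$). The learned field `nu` and the guidance gradient `grad_log_p_y` are assumed to be supplied.
```python
import numpy as np

def guided_sample(nu, grad_log_p_y, x1, n_steps=100, eps=1e-3):
    """Euler integration of dx = -nu_tilde_t(x|y) dt from t=1-eps down to t=eps."""
    x = np.asarray(x1, dtype=float)                     # start from noise at t ~= 1
    ts = np.linspace(1.0 - eps, eps, n_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        nu_tilde = nu(t, x) + (t / (1.0 - t)) * grad_log_p_y(t, x)
        x = x + (t_next - t) * (-nu_tilde)              # note: t_next - t < 0
    return x
```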
### Motif-scaffolding with reconstruction guidance
- We assume a likelihood of the form $p(y|x_0) = \mathrm N(y|A(x_0), \sigma_y^2) = \mathrm N(y|QMx_0, \sigma_y^2)$
with $M$ the masking linear operator selecting the motif from the full backbone, and $Q$ the matrix aligning the predicted motif and observed motifs.
- We assume that $M$ is known but not $Q$, thus we rely on an MC estimator to marginalise it out:
$p(y|x_0) = \int_{\SO} \mathrm N(y|QMx_0, \sigma_y^2) p(Q) \rmd Q$.
- In Wu et al. 2023, they use $p = \mathcal U (\SO)$
- It has also been proposed to use $p = \delta_{\hat Q}$ where $\hat Q$ is the rotation that best aligns the predicted motif with the observed motif, as given by the Kabsch algorithm
- We suggest using $p = \mathrm N_{\text{wrapped}}(\hat Q, \sigma_Q^2)$
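A minimal sketch of the $\delta_{\hat Q}$ variant (scipy's `Rotation.align_vectors` implements the Kabsch estimate). The motif coordinates are hypothetical, and centring both point clouds before alignment is an extra assumption here, since the integral above only marginalises over the rotation.
```python
import numpy as np
from scipy.spatial.transform import Rotation

def motif_log_likelihood(y, x0_pred, motif_idx, sigma_y=1.0):
    """y: (m, 3) observed motif; x0_pred: (N, 3) predicted backbone coordinates."""
    m_pred = x0_pred[motif_idx]                        # M x0: select the motif residues
    y_c, m_c = y - y.mean(0), m_pred - m_pred.mean(0)  # centre both point clouds
    Q_hat, _ = Rotation.align_vectors(y_c, m_c)        # Kabsch: best rotation m_c -> y_c
    resid = y_c - Q_hat.apply(m_c)
    return -0.5 * np.sum(resid**2) / sigma_y**2        # log N(y | Q_hat M x0, sigma_y^2) + const

y = np.random.randn(5, 3)
x0_pred = np.random.randn(20, 3)
print(motif_log_likelihood(y, x0_pred, np.arange(5)))
```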
## Scheduling
### Flow matching (on manifolds)
Given a monotonically decreasing differentiable function $\kappa(t)$ satisfying $\kappa(0) = 1$ and $\kappa(1) = 0$, we want to find a flow $\phi_t$ that decreases $d(\cdot, x_1)$ according to $d(\phi_t(x_0|x_1), x_1) = \kappa(t) d(x_0, x_1)$.
Hence, $\kappa(t)$ acts as a scheduler that determines the rate at which $d(\cdot, x_1)$ decreases.
We then have
- $$ u_t(x|x_1) = \frac{\rmd \log \kappa(t)}{\rmd t} d(x, x_1) \frac{\nabla d(x, x_1)}{\|\nabla d(x, x_1)\|^2}
= \frac{\rmd \log \kappa(t)}{\rmd t} \frac{\nabla d(x, x_1)^2}{2\|\nabla d(x, x_1)\|^2} \\
= \frac{\rmd \log \kappa(t)}{\rmd t} \frac{-\log_x(x_1)}{\|\nabla d(x, x_1)\|^2}
= \frac{- \rmd \log \kappa(t)}{\rmd t} \log_x(x_1). $$
- Interpolation: $x_t = \phi_t(x_0|x_1) = I(x_0, x_1, t) = \exp_{x_1}(\kappa(t)\log_{x_1}(x_0)) = \exp_{x_0}((1-\kappa(t))\log_{x_0}(x_1))$
- Euler update: $x_{t+\Delta_t} \approx \exp_{x_t}\left(-\frac{\rmd \log \kappa(t)}{\rmd t}\,\Delta_t\, \log_{x_t}(x_1)\right) \approx \exp_{x_t}\left(-\frac{\rmd \log \kappa(t)}{\rmd t}\,\Delta_t\, \log_{x_t}(\hat x_1(t, x_t))\right)$
- examples:
- _linear_: $\kappa(t) = 1 - t \Rightarrow \frac{\rmd \log \kappa(t)}{\rmd t} = \frac{-1}{1 - t}$, $\kappa(0)=1, \kappa(1)=0$
- _exponential_: $\kappa(t) = e^{-c t} \Rightarrow \frac{\rmd \log \kappa(t)}{\rmd t} = -c$, $\kappa(0)=1$, $\kappa(1)=e^{-c} \rightarrow 0$ as $c \rightarrow \infty$
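A minimal sketch of the scheduled Euler update on $\SO$ (scipy rotations; in practice `r_1` would be the network prediction $\hat x_1(t, x_t)$):
```python
import numpy as np
from scipy.spatial.transform import Rotation

def dlog_kappa(t, schedule="linear", c=5.0):
    """d/dt log kappa(t) for the two example schedules above."""
    if schedule == "linear":          # kappa(t) = 1 - t
        return -1.0 / (1.0 - t)
    return -c                         # kappa(t) = exp(-c t)

def euler_step_so3(r_t, r_1, t, dt, schedule="linear"):
    """x_{t+dt} = exp_{x_t}(-(d/dt log kappa)(t) * dt * log_{x_t}(x_1)) on SO(3)."""
    log_rt_r1 = (r_t.inv() * r_1).as_rotvec()          # log_{r_t}(r_1), body frame
    step = -dlog_kappa(t, schedule) * dt * log_rt_r1
    return r_t * Rotation.from_rotvec(step)            # exp_{r_t}(step)

r_t, r_1 = Rotation.random(), Rotation.random()
print(euler_step_so3(r_t, r_1, t=0.3, dt=0.01).as_matrix())
```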
### Bose et al. 2023
Scaling up the vector field via $c(t)$:
$$\rmd R_t = c(t) u_\theta(t, R_t) \rmd t + \gamma(t) \rmd B_t$$
### Albergo et al. 2023 (assuming Euclidean support)
- _Barycentric stochastic interpolant_: Given $K+1$ pdfs $\{\rho_k\}_k$, and $x_k \sim \rho_k$, the interpolant with coordinates $\alpha=(\alpha_0, \dots, \alpha_K) \in \Delta^K$ is $x(\alpha) = I(\alpha, x) = \alpha^\top x = \sum_{k=0}^K \alpha_k x_k$ (a code sketch follows after this list).
- $\partial_t I(\alpha, x) = \sum_{k=0}^K \dot{\alpha}_k(t) x_k$
- $\E[\partial_t I(\alpha, x(\alpha))|x(\alpha)=x] = \sum_{k=0}^K \dot{\alpha}_k(t) \E[x_k|x(\alpha)=x]$
- _Continuity equations_: The probability distribution of the stochastic interpolant $x(\alpha)$ has a density $\rho(\alpha, x)$ which satisfies the $K+1$ equations:
$$ \partial_{\alpha_k} \rho(\alpha, x) + \nabla_x \cdot (u_k(\alpha, x) \rho(\alpha, x)) = 0 $$
with $u_k(\alpha, x) \triangleq \E[x_k | x(\alpha) = x].$
- Moreover: $\E[x(\alpha)|x(\alpha)=x] = x \Leftrightarrow \sum_k \alpha_k(t) u_k(t, x) = x \\ \Rightarrow u_0(t, x) = 1/\alpha_0(t) (x - \sum_{k\geq 1} \alpha_k(t) u_k(t, x))$
- _Optimisation of vector fields_: Each $u_k$ is the unique minimiser of: $\mathcal{L}_k(\hat{u}_k) = \int_{\Delta^K} \E \| \hat{u}_k(\alpha, x(\alpha)) - x_k \|^2 \rmd \alpha$
- _Transport equation_:
- Let $\alpha: (0, 1) \rightarrow \Delta^K$ with $\alpha(0)=e_i$ and $\alpha(1)=e_j$.
- The barycentric interpolant $x(\alpha(t))$ has density $\rho(\alpha(t), x)$ and satisfies the continuity equation with boundary conditions $\rho(t=0, \cdot) = \rho_i$ and $\rho(t=1, \cdot) = \rho_j$
- with velocity field $u(t, x) = \sum_{k=0}^K \dot{\alpha}_k(t) u_k(\alpha(t), x)$.
- $\dot{x_t} = \frac{\rmd \phi_t(x_0)}{\rmd t} = u(t, x_t)$
- _Optimisation of path $\alpha: (0, 1) \rightarrow \Delta_K$_: The solution to $\mathcal{C}(\alpha)$ gives the transport with least path length in Wasserstein-2 metric (assuming a drift linear in $u_k$s):
$$\mathcal{C}(\alpha) = \min_{\hat{\alpha}} \int_0^1 \E \left[ \| \sum_{k=0}^K \dot{\alpha}_k(t) u_k(\hat{\alpha}(t), x(\hat{\alpha}(t))) \|^2 \right] \rmd t$$
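A minimal sketch (Euclidean case, hypothetical marginal samples) of the barycentric interpolant and its time derivative along a path $\alpha(t)$ on the simplex:
```python
import numpy as np

def barycentric_interpolant(alpha, xs):
    """alpha: (K+1,) point on the simplex; xs: (K+1, d), one sample per marginal."""
    return alpha @ xs                                  # x(alpha) = sum_k alpha_k x_k

def d_dt_interpolant(dalpha_dt, xs):
    """partial_t I(alpha(t), x) = sum_k alpha_k'(t) x_k."""
    return dalpha_dt @ xs

xs = np.random.randn(3, 2)                             # K + 1 = 3 marginal samples in R^2
alpha, dalpha_dt = np.array([0.5, 0.25, 0.25]), np.array([-1.0, 0.5, 0.5])
print(barycentric_interpolant(alpha, xs), d_dt_interpolant(dalpha_dt, xs))
```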
### Interpolating from the noise to the data distribution over $\SE^N / \R^3 = \SO^N \ltimes (\R^{3 \times N} / \R^3)$:
- Interpolations (a code sketch follows after this list):
- translations
- $x_t = \phi^\trans_t(x_0|x_1) = I^\trans(x_0, x_1, t) = (1 - t)x_0 + tx_1$
- $\frac{\rmd}{\rmd t} \phi^\trans_t(x_0|x_1) = \partial_t I^\trans(x_0, x_1, t) = x_1 - x_0$
- $u^\trans(x_t|x_1) = \frac{x_1 - x_t}{1 - t} = \frac{x_1 - (1 - t)x_0 - tx_1}{1-t} = \frac{(1 - t)x_1 - (1 - t)x_0}{1-t} = x_1 - x_0$
- $u_\theta^\trans = \frac{\hat{x}^\theta_1 - x_t}{1 - t}$ where $\hat{x}^\theta_1 \approx \E[x_1 | T_t]$
- $\L(\theta) = \E_{x_1, x_0} \|u_\theta^\trans(t, \phi^\trans_t(x_0|x_1)) - \frac{\rmd}{\rmd t} \phi^\trans_t(x_0|x_1)\|^2 \\
= \E_{x_1, x_0} \| \frac{\hat{x}^\theta_1 - x_t}{1 - t} - \frac{x_1 - x_t}{1-t} \|^2
= \E_{x_1, x_0} \left[\frac{1}{(1-t)^2} \|\hat{x}^\theta_1 - x_1 \|^2 \right]$
- rotations:
- $r_t = \phi^\rot_t(r_0|r_1) = I^\rot(r_0, r_1, t) = \exp_{r_0}(t\log_{r_0}(r_1)) = \exp_{r_1}((1-t)\log_{r_1}(r_0))$
- $u^\rot(r_t|r_1) = \frac{\log_{r_t}(r_1)}{1 - t}$
- $u_\theta^\rot = \frac{\log_{r_t}(\hat{r}^\theta_1)}{1 - t}$ where $\hat{r}^\theta_1 \approx \E[r_1 | T_t]$
- $\L(\theta) = \E_{x_1, x_0} \|u_\theta^\rot(t, \phi^\rot(x_0|x_1)) - \frac{\rmd}{\rmd t} \phi^\rot_t(x_0|x_1)\|^2 \\
= \E_{x_1, x_0} \| \frac{\log_{r_t}(\hat{r}^\theta_1)}{1 - t} - \frac{\log_{r_t}(r_1)}{1 - t} \|^2
= \E_{x_1, x_0} \left[\frac{1}{(1-t)^2} \|\log_{r_t}(\hat{r}^\theta_1) - \log_{r_t}(r_1) \|^2 \right]$
- $\Emile \Rightarrow$ has same minimiser as $\L(\theta) = \E_{x_1, x_0} \left[\frac{1}{(1-t)^2} d(\hat{r}^\theta_1, r_1)^2\right]$, worth trying minimising this?
- manifolds:
- $x(\alpha) = I(\alpha, x) = \exp_{x^\star}\left(\sum_{k=0}^K \alpha_k \log_{x^\star}(x_k)\right)$ with some arbitrary anchor $x^\star \in \M$
- e.g. $x^\star = x_0 \Rightarrow x(\alpha) = \exp_{x_0}\left(\sum_{k=1}^K \alpha_k \log_{x_0}(x_k)\right)$ and if $K=1$ we fall back on $x(t) = \exp_{x_0}\left(t \log_{x_0}(x_1)\right)$ with $\alpha(t) = (1-t, t)$.
- $\partial_t I(\alpha, x(\alpha)) = \partial_t \exp_{x_0}\left(\sum_{k=1}^K \alpha_k(t) \log_{x_0}(x_k)\right) \\
= \partial \exp_{x_0}\left(\sum_{k=1}^K \alpha_k(t) \log_{x_0}(x_k)\right) \left(\sum_{k=1}^K \dot{\alpha}_k(t) \log_{x_0}(x_k)\right)$
- $\E[\partial_t I(\alpha, x(\alpha))|x(\alpha)=x] = ?$
- $\E[\sum_{k=1}^K \dot{\alpha}_k(t) \log_{x_0}(x_k)|x(\alpha)=x] = \sum_{k=1}^K \dot{\alpha}_k(t) ~\E[\log_{x_0}(x_k)|x(\alpha)=x]
\le \sum_{k=1}^K \dot{\alpha}_k(t) \log_{x_0}(\E[x_k|x(\alpha)=x])$
- $\E[\partial_t I(\alpha, x(\alpha))|x(\alpha)=x] \approx \partial \exp_{x_0}\left(\sum_{k=1}^K \alpha_k(t) \log_{x_0}(\E[x_k|x(\alpha)=x])\right) \sum_{k=1}^K \dot{\alpha}_k(t) \log_{x_0}(\E[x_k|x(\alpha)=x])$
- PDFs: Let's define
/!\ $\Emile$: this is wrong as it would be targeting the product of the marginals and not the joint /!\
- $\rho_0 \triangleq \rho_0^\rot \otimes \rho_0^\trans$ with $\rho_0^\rot \triangleq \mathcal{U}(\SO^{\otimes N})$ and $\rho_0^\trans \triangleq \mathcal{N}(\R^{3\otimes N} / \R^3)$ (i.e. a standard normal on the quotient space where the CoM is centred.)
- $\rho_1^\rot \in \mathcal{P}_1^+(\SO^{\otimes N})$ the data distribution of rotations
- $\rho_1^\trans \in \mathcal{P}_1^+(\R^{3 \otimes N} / \R^3)$ the data distribution of translations
- Stochastic interpolant:
- We consider the simplex $\Delta^2$ where the top vertex is $\rho_0$, the left bottom one is $\rho_1^\rot$ and the right bottom one is $\rho_1^\trans$.
- At training time:
- Should sample uniformly?
- Sampling the **same** scheduling time $t \in (0, 1)$ for both rotations and translations yields a **unique** curve $\alpha(t)$ over the simplex $\Rightarrow$ this implies that at inference time we would have to use the same scheduling $\alpha(t)$
- Sampling different scheduling times $t^\rot, t^\trans \in (0, 1)$ for rotations and translations, with full support over the simplex, would allow using any path $\alpha(t)$ at inference time!
- _Path_: We want to construct / learn a path $\alpha: (0, 1) \rightarrow \Delta^2$ such that $\alpha(0) = e_0$ (top vertex) and $\alpha(1) = e_{12}$ (associated with $\rho_1^\rot \otimes \rho_1^\trans \triangleq \rho_1$) (bottom centre edge).
- Using a linear schedule (with the same speed) for both rotations and translations would yield a straight vertical path
- Using a faster schedule for rotations would yield a (concave) path leaning to the left side of the simplex.
- Goal: learn $\alpha(t)$!
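The code sketch referenced in the interpolation bullets above (scipy rotations, linear schedule for both components, a single interpolation time $t$; function and variable names are hypothetical): sampling the interpolant and the conditional targets used in the translation and rotation losses.
```python
import numpy as np
from scipy.spatial.transform import Rotation

def interpolate_frames(x0, x1, r0, r1, t):
    """x*: (N, 3) translations, r*: stacks of N scipy Rotations, t in (0, 1)."""
    x_t = (1 - t) * x0 + t * x1                                       # I^trans(x0, x1, t)
    r_t = r0 * Rotation.from_rotvec(t * (r0.inv() * r1).as_rotvec())  # exp_{r0}(t log_{r0}(r1))
    u_trans = x1 - x0                                                 # = (x1 - x_t)/(1 - t)
    u_rot = (r_t.inv() * r1).as_rotvec() / (1 - t)                    # log_{r_t}(r1)/(1 - t)
    return x_t, r_t, u_trans, u_rot

N = 4
x0, x1 = np.random.randn(N, 3), np.random.randn(N, 3)
r0, r1 = Rotation.random(N), Rotation.random(N)
x_t, r_t, u_trans, u_rot = interpolate_frames(x0, x1, r0, r1, t=0.3)
```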
### Non-linear multimarginal interpolant
The original derivation of the amortized path choice in the multimarginal setup implied some linearity in the interpolant.
Let's consider the barycentric interpolant of the form
$$
x(\alpha) = \sum_k \alpha_k x_k \qquad \text{with} \qquad \sum_k \alpha_k = 1
$$
The argument then was that there exists $g_k(\alpha, x)$ for which
$$
g_k(\alpha, x) = \mathbb E[x_k | x(\alpha) = x]
$$
which comes from the fact that $\partial_{\alpha_k} x(\alpha)=x_k$, and we have
$$
b(t,x) = \sum_k \dot \alpha_k(t) g_k(\alpha(t), x).
$$
However, the statement can be made more general: instead of considering the linear interpolant above, we just make the choice
$$
x(\alpha) = I(\alpha, \bar x)
$$
for any nonlinear function $I$ that satisfies the right boundary conditions with respect to the interpolation coordinate $\alpha$, where $\bar x$ is the column-wise vector of the marginal samples $x_k \sim \rho_k$, e.g. $\bar x = [x_1, \dots, x_K]$ with $x_k \in \mathcal X$ for whatever space $\mathcal X$.
In this case, we have two choices for representing the analytic form of the velocity field. The first is to continue with $\bar x$. Here we note that, more generically, an operator from the space of the coordinates $\alpha$ (call it $S$) crossed with the space $\mathcal X$ (call it $\mathbb R^d$ for now), defined as $g(\alpha, x): S \times \mathbb R^d \rightarrow \mathbb R^{K\times d}$, gives each of these marginal velocities $g_k$ at once via:
$$
g(\alpha, x) = \mathbb E[ \nabla_\alpha I(\alpha, \bar x) | x(\alpha) = x]
$$
where again $x(\alpha) = I(\alpha, \bar x)$. The expression for $b$ then still holds, but should be written more generally as
$$
b(t, x) = \langle \dot \alpha(t), g(\alpha, x) \rangle
$$
where $\langle \cdot, \cdot \rangle$ denotes the sum over $k$. Note that whatever $I$ is, it still needs to satisfy the boundary conditions.
Alternatively, one could learn each $g_k$ separately as
$$g_k(\alpha, x) = \mathbb E[\partial_{\alpha_k} I(\alpha, \bar x) | x(\alpha) = x]$$
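A minimal sketch (PyTorch; `g_theta` is a hypothetical regressor and, for concreteness, `I` is the linear interpolant, though any $I$ with the right boundary conditions works): the regression target $\nabla_\alpha I(\alpha, \bar x)$ obtained with autograd, whose conditional expectation given $x(\alpha) = x$ is $g(\alpha, x)$.
```python
import torch
import torch.nn as nn

K, d = 2, 2                                            # K + 1 = 3 marginals in R^d

def I(alpha, x_bar):
    return alpha @ x_bar                               # linear case: grad_alpha I = x_bar

g_theta = nn.Sequential(nn.Linear((K + 1) + d, 64), nn.SiLU(), nn.Linear(64, (K + 1) * d))

def g_regression_loss(alpha, x_bar):
    x_alpha = I(alpha, x_bar)                                                    # x(alpha)
    target = torch.autograd.functional.jacobian(lambda a: I(a, x_bar), alpha).T  # (K+1, d)
    pred = g_theta(torch.cat([alpha, x_alpha])).reshape(K + 1, d)
    return ((pred - target) ** 2).mean()

loss = g_regression_loss(torch.tensor([0.5, 0.25, 0.25]), torch.randn(K + 1, d))
loss.backward()
```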
#### What we need to do!
- Network architecture must take $\alpha(t) = (\alpha_0, \alpha_1, \alpha_2)$ (vector) as input. Alternatively, having independent times $(t^\trans, t^\rot)$ (`noisy_batch['so3_t']` and `noisy_batch['r3_t']`) for rotations and translations is equivalent: $\alpha(t) = f(t^\trans, t^\rot) = (1 - (t^\trans + t^\rot)/2, t^\trans/2, t^\rot/2)$.
- At training time we need to sample $\alpha$ with full support over the simplex. E.g. we can simply use a linear schedule and independent interpolation times, which is already the case when `interpolant.separate_t == True`: $t^\trans, t^\rot \sim \mathcal{U}[0,1]^{\otimes 2}$.
- Construct a neural network for parameterising the path:
- $\alpha_\psi: (0, 1) \rightarrow \Delta^2$ with boundary conditions $\alpha(0) = e_0$ (top vertex) and $\alpha(1) = e_{12}$. E.g. see Appendix B of Albergo et al. 2023.
- Alternatively we can work within the $\alpha(t) = f(t^\trans, t^\rot)$ parameterisation. Then we need two monotonic endomorphisms (parameterised with neural networks) mapping $t: (0, 1) \rightarrow (0, 1)$, e.g. $t^\trans(s) = s + \left( \sum_n^N a_n \sin(\frac{\pi}{2}s) \right)^2$ and $t^\rot(s) = s + \left( \sum_n^N b_n \sin(\frac{\pi}{2}s) \right)^2$ (see the sketch after this list).
- Learn parameters $\psi$ so as to minimise $\mathcal{C}(\alpha_\psi) = \int_0^1 \E \left[ \| u_t(T(\alpha(t))) \|^2 \right] \rmd t$ with $u$ the drift/velocity: $\rmd T_t = u_t(T(\alpha(t))) \rmd t$. $\Emile$: what expression??
<!-- $$\mathcal{C}(\alpha_\psi) = \int_0^1 \E \left[ \| \sum_{k=0}^K \dot{\alpha}_k(t) u_k(\hat{\alpha}(t), x(\hat{\alpha}(t))) \|^2 \right] \rmd t \\
=\int_0^1 \E \left[ \|\dot{\alpha}_0(t)/\alpha_0(t) \left((\alpha(t)) - \alpha_\trans(t) \hat{x}^\theta(\alpha(t), T(\alpha(t))) - \alpha_\rot(t) \hat{r}^\theta(\alpha(t), T(\alpha(t)) \right) \|^2 \\ + \|\dot{\alpha}_\trans(t) \hat{x}^\theta(\alpha(t), T(\alpha(t)))\|^2 +\|\dot{\alpha}_\rot(t) \hat{r}^\theta(\alpha(t), T(\alpha(t)))\|^2 \right] \rmd t$$ -->
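The sketch referenced above (PyTorch). This is one way, assumed here rather than taken from the note, to get a monotone endomorphism of $(0, 1)$ with the right endpoints: integrate a positive "speed" given by a small network and renormalise; it is an alternative to the sine parameterisation suggested above.
```python
import torch
import torch.nn as nn

class MonotoneTimeWarp(nn.Module):
    """t(s): (0, 1) -> (0, 1), monotone, with t(0) = 0 and t(1) = 1."""
    def __init__(self, hidden=32, n_grid=256):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1))
        self.register_buffer("grid", torch.linspace(0.0, 1.0, n_grid).unsqueeze(-1))

    def forward(self, s):
        # positive "speed" on a fixed grid, cumulatively integrated and renormalised
        speed = nn.functional.softplus(self.f(self.grid)).squeeze(-1)   # (n_grid,)
        cdf = torch.cumsum(speed, dim=0)
        cdf = (cdf - cdf[0]) / (cdf[-1] - cdf[0])                       # pins t(0)=0, t(1)=1
        idx = (s.clamp(0.0, 1.0) * (len(cdf) - 1)).long()
        return cdf[idx]

warp = MonotoneTimeWarp()
print(warp(torch.tensor([0.0, 0.5, 1.0])))                              # starts at 0, ends at 1
```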
### Chat 3/11/2023
$\Michael$ (+ edit from $\Emile$):
- $b(t,x) = \dot\alpha_0(t) \E[x_0 | x(\alpha(t)) = x] + \dot\alpha_1(t) \E[x_1 | x(\alpha(t)) = x]$
- $x = \E[x | x(\alpha(t)) = x] = \alpha_0(t) \E[x_0 | x(\alpha(t)) = x] + \alpha_1(t) \E[x_1 | x(\alpha(t)) = x]$
- $\E[x_0 | x(\alpha(t)) = x] = (x - \alpha_1(t) \E[x_1 | x(\alpha(t)) = x]) / \alpha_0(t)$
- $b(t,x) = \dot\alpha_1(t) \E[x_1 | x(\alpha(t)) = x] + \dot\alpha_0(t) (x - \alpha_1(t) \E[x_1 | x(\alpha(t)) = x]) / \alpha_0(t)$
- $\alpha^{\theta}(t)$ $\rightarrow$ how to find best $\alpha^*(t)$?
- Generally $\alpha_k: [0, 1] \rightarrow \mathrm{GL}$, i.e. $I(t, x_0, x_1) = \alpha_0(t) x_0 + \alpha_1(t) x_1$ (in a matrix-vector operator sense)
- e.g. in the scalar setting $\alpha_k(t) = \tilde\alpha_k(t) \mathrm I$ with $\tilde\alpha_k(t) \in [0, 1]$
- in our setting, we want $\alpha_0(t) = \texttt{block\_diag}([\alpha_0^\trans(t) \mathrm I, \alpha_0^\rot(t) \mathrm I])$ with $\alpha_0^\trans(t), \alpha_0^\rot(t) \in [0,1]$ (i.e. scalar) and boundary conditions $\alpha_0(0)=\mathrm I$ and $\alpha_0(1)=0$.
- Training: sampling $\alpha = [\alpha_0, \alpha_1] = [\alpha_0, 1-\alpha_0] = [(\alpha_0^\trans, \alpha_0^\rot), (1-\alpha_0^\trans, 1-\alpha_0^\rot)]$
- sampling $\alpha_0^\trans, \alpha_0^\rot \sim \mathcal{U}[0,1]^{\otimes 2}$?
- Learn model for (denoiser) $\E[x_1 | x(\alpha(t))] = \hat x^\theta_1(\alpha, x)$
- $L_{\eta_0} = \int_{\Delta} \E \left(\|\eta_0(\alpha, x)\|^2 - 2\langle x_0, \eta_0(\alpha, x)\rangle \right) \rmd x\, \rmd\alpha = \int_{\Delta} \E \left(\|\eta_0(\alpha, x) - x_0 \|^2 \right) \rmd x\, \rmd\alpha$ (up to a constant in $\eta_0$)
- $\mathcal L (\hat \alpha) = \int_0^1 \int \|b(\hat\alpha(t), x)\|^2 \rho(\hat\alpha(t), x)\, \rmd x\, \rmd t$, to be minimised over $\hat\alpha$ (i.e. the squared $L^2(\rho)$ norm of the drift)
\begin{align}
b(t,x) &= \mathbb E[\partial_t I_t | x_t = x] \leftarrow \hat b(t,x) \\
&= \dot\alpha(t) \mathbb E [x_0 | x_t] + \dot \beta(t) \mathbb E[x_1 | x_t]
\end{align}
$\Emile$ is learning $\E[x_1 | x_t]$, right?
Loss for $\hat \alpha(t)$? (A sketch assembling $b$ from such a denoiser follows below.)
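A minimal sketch of the two-marginal drift assembly above (plain Python/numpy; the denoiser is a placeholder): $b(t,x)$ from a learned $\E[x_1 | x(\alpha(t)) = x]$, using $\E[x_0|x] = (x - \alpha_1(t)\E[x_1|x]) / \alpha_0(t)$.
```python
import numpy as np

def drift(t, x, x1_denoiser, alpha0, alpha1, dalpha0, dalpha1):
    """b(t, x) = alpha0'(t) E[x0 | x] + alpha1'(t) E[x1 | x]."""
    x1_hat = x1_denoiser(t, x)
    x0_hat = (x - alpha1(t) * x1_hat) / alpha0(t)
    return dalpha0(t) * x0_hat + dalpha1(t) * x1_hat

# Linear path alpha_0(t) = 1 - t, alpha_1(t) = t, placeholder denoiser.
b = drift(0.3, np.zeros(3), lambda t, x: np.ones(3),
          alpha0=lambda t: 1 - t, alpha1=lambda t: t,
          dalpha0=lambda t: -1.0, dalpha1=lambda t: 1.0)
print(b)   # equals (x1_hat - x)/(1 - t) for the linear schedule
```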