The goal is to learn latent segmentation masks that are guided by the measurements.
Let $X(t)\in\mathbb{R}^{C\times D\times H\times W}$ denote a 3D frame containing the observations from the electrolyser at time $t$. Let $Y(t)\in\mathbb{R}^{C'}$ denote the corresponding measurement vector at time $t$, which contains a set of $C'$ measurements. A slice of the 3D frame at time $t$ and coordinate $z$ is denoted $X(z,t)\in\mathbb{R}^{C\times H\times W}$.
# Setup
At each time $t$, we receive a 2D slice of the 3D observations at coordinate $z_t$, $X(z_t,t)\in\mathbb{R}^{C\times H\times W}$, along with the corresponding measurement $Y(t)$. The goal is to learn a latent segmentation $Z(z_t,t)$ for the observation $X(z_t,t)$ that yields a predicted measurement $\hat{Y}(t)$ as close as possible to the true measurement $Y(t)$.
In this case:
$$
C = 69, H=193, W=501
$$
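For concreteness, a minimal shape check in PyTorch (only $C$, $H$, $W$ are fixed above; the depth $D$ and the number of measurements $C'$ below are placeholder values):
```python
import torch

C, D, H, W = 69, 4, 193, 501    # D is a placeholder; only C, H, W are given above
C_prime = 8                     # number of measurements C', also a placeholder

X_t = torch.randn(C, D, H, W)   # full 3D frame X(t) (not fully observed in practice)
z_t = 2                         # slice coordinate observed at time t
X_slice = X_t[:, z_t]           # observed slice X(z_t, t), shape (C, H, W)
Y_t = torch.randn(C_prime)      # measurement vector Y(t)

print(X_slice.shape, Y_t.shape)  # torch.Size([69, 193, 501]) torch.Size([8])
```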
# Joint modeling of measurements and observations
$$
\begin{align}
P(X_1,\dots,X_T,Y_1,\dots,Y_T) & = P(X_{\leq T})P(Y_{\leq T}|X_{\leq T})
\end{align}
$$
### Learning the measurements from the inputs
We're interested in the following quantity:
$$
P(Y_{\leq T}|X_{\leq T})
$$
We can write it as:
$$
\begin{align}
P(Y_{\leq T}|X_{\leq T}) & = \int P(Y_{\leq T}|X_{\leq T}, Z_{\leq T})P(Z_{\leq T}|X_{\leq T})dZ_{\leq T}
\end{align}
$$
But:
$$
P(Z_{\leq T}|X_{\leq T}) = \prod_{t=1}^T P(Z_t|Z_{<t}, X_{\leq T})
$$
We assume that $Z_t$ does not depend on future observations $X_{>t}$, so that:
$$
P(Z_t|Z_{<t}, X_{\leq T}) = P(Z_t|Z_{<t}, X_{\leq t})
$$
which translates to:
$$
P(Y_{\leq T}|X_{\leq T}) = \int P(Y_{\leq T}|X_{\leq T}, Z_{\leq T})\prod_{t=1}^T P(Z_t|Z_{<t}, X_{\leq t})dZ_{\leq T}
$$
Applying the same argument to the measurement term, we can write:
$$
P(Y_{\leq T}|X_{\leq T}) = \int \prod_{t=1}^T P(Y_t|Y_{<t}, X_{\leq t}, Z_{\leq t}) P(Z_t|Z_{<t}, X_{\leq t})dZ_{\leq T}
$$
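For concreteness, unrolling this factorization for $T=2$ gives:
$$
P(Y_1, Y_2|X_1, X_2) = \int P(Y_1|X_1, Z_1)\,P(Z_1|X_1)\,P(Y_2|Y_1, X_{\leq 2}, Z_{\leq 2})\,P(Z_2|Z_1, X_{\leq 2})\,dZ_1\,dZ_2
$$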
In this case, the latent variables $Z_{\leq t}$ represent the segmentations and can be drawn from a categorical distribution whose parameters are learned by a neural network.
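A minimal sketch of such a network is given below (the architecture, class count, and the use of a simple recurrent feature map to summarize the past are assumptions, not part of the setup): it maps the current slice and a hidden state to per-pixel class logits.
```python
import torch
import torch.nn as nn

class SegmentationPrior(nn.Module):
    """Sketch of P(Z_t | Z_<t, X_<=t): per-pixel categorical logits over K classes.

    The hidden map h is a simple stand-in for the dependence on the past
    (Z_<t, X_<t); a ConvGRU/ConvLSTM would be a more faithful choice.
    """
    def __init__(self, in_channels=69, hidden=32, num_classes=4):
        super().__init__()
        self.hidden = hidden
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + hidden, hidden, 3, padding=1), nn.ReLU(),
            nn.Conv2d(hidden, num_classes, 1),
        )

    def forward(self, x_slice, h=None):
        # x_slice: (B, C, H, W) observed slice at time t; h: (B, hidden, H, W) or None
        if h is None:
            h = x_slice.new_zeros(x_slice.size(0), self.hidden, *x_slice.shape[2:])
        logits = self.conv(torch.cat([x_slice, h], dim=1))  # (B, K, H, W)
        return logits
```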
### Learning the missing data
We also know that we do not have access to the full 3D frame at each time step, but only to some slices. So we can further decompose $X_t$ into two components:
- An observable one (which would be the 2D slice in our case): $X_t^o$
- A missing one (which we would need to infer): $X_t^{\bar{o}}$
In that case, we can write the joint probability of the observations as:
$$
\begin{align}
P(X_{\leq T}) & = P(X^o_{\leq T}, X^{\bar{o}}_{\leq T}) \\
P(X_{\leq T}) & = P(X^{\bar{o}}_T| X^o_{\leq T}, X^{\bar{o}}_{<T})P(X^o_{\leq T}, X^{\bar{o}}_{<T}) \\
P(X_{\leq T}) & = P(X^{\bar{o}}_T| X^o_{\leq T}, X^{\bar{o}}_{<T})P(X^{\bar{o}}_{T-1}| X^o_{\leq T}, X^{\bar{o}}_{<T-1})P(X^o_{\leq T}, X^{\bar{o}}_{<T-1}) \\
& \vdots \\
P(X_{\leq T}) & = P(X^o_{\leq T})\prod_{t=1}^T P(X^{\bar{o}}_t| X^o_{\leq T}, X^{\bar{o}}_{<t})
\end{align}
$$
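One way to realize the factors $P(X^{\bar{o}}_t| X^o_{\leq T}, X^{\bar{o}}_{<t})$ is an imputation network that predicts each missing slice from a summary of the observed slices and the previously imputed ones. The sketch below is only illustrative: the pooled summary, the per-pixel Gaussian output, and all names are assumptions.
```python
import torch
import torch.nn as nn

class SliceImputer(nn.Module):
    """Sketch of P(X^miss_t | X^obs_<=T, X^miss_<t) as a per-pixel Gaussian."""
    def __init__(self, channels=69, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2 * channels, hidden, 3, padding=1), nn.ReLU(),
            nn.Conv2d(hidden, 2 * channels, 1),  # predicts mean and log-variance
        )

    def forward(self, obs_summary, prev_missing):
        # obs_summary:  (B, C, H, W) e.g. the mean over all observed slices
        # prev_missing: (B, C, H, W) previously imputed slice (zeros at t = 1)
        out = self.net(torch.cat([obs_summary, prev_missing], dim=1))
        mean, logvar = out.chunk(2, dim=1)  # parameters of the Gaussian over X^miss_t
        return mean, logvar
```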
### Putting everything together (the VAE way)
If we only care about inferring the missing observations from the observed ones (without generating the already observed ones too), we can instead just learn the following conditional distribution:
$$
P(Y_{\leq T}, X^{\bar{o}}_{\leq T}|X^{o}_{\leq T})=P(Y_{\leq T}|X_{\leq T})P(X^{\bar{o}}_{\leq T}|X^{o}_{\leq T})
$$
This translates to the following log-probability formulation:
$$
\begin{align}
\log P(Y_{\leq T}, X^{\bar{o}}_{\leq T}|X^{o}_{\leq T}) & =\log P(Y_{\leq T}|X_{\leq T})+\log P(X^{\bar{o}}_{\leq T}|X^{o}_{\leq T}) \\
& = \log\left(\int \prod_{t=1}^T P(Y_t|Y_{<t}, X_{\leq t}, Z_{\leq t}) P(Z_t|Z_{<t}, X_{\leq t})dZ_{\leq T}\right)+\sum_{t=1}^T\log P(X^{\bar{o}}_t| X^o_{\leq T}, X^{\bar{o}}_{<t})
\end{align}
$$
We can derive the following lower bound (Jensen's inequality + importance sampling) \[1, 2\]:
$$
\begin{align}
\log\left(\int \prod_{t=1}^T P(Y_t|Y_{<t}, X_{\leq t}, Z_{\leq t}) P(Z_t|Z_{<t}, X_{\leq t})dZ_{\leq T}\right) & \geq \int q_{\phi}(Z_{\leq T}|X_{\leq T}, Y_{\leq T})\left(\sum_{t=1}^T\left(\log P(Y_t|Y_{<t}, X_{\leq t}, Z_{\leq t})+\log P(Z_t|Z_{<t}, X_{\leq t})\right) - \log q_{\phi}(Z_{\leq T}|X_{\leq T}, Y_{\leq T})\right) dZ_{\leq T} \\
& = \mathbb{E}_{q_{\phi}(Z_{\leq T}|X_{\leq T}, Y_{\leq T})}\left(\sum_{t=1}^T\left(\log P(Y_t|Y_{<t}, X_{\leq t}, Z_{\leq t})+\log P(Z_t|Z_{<t}, X_{\leq t})\right) - \log q_{\phi}(Z_{\leq T}|X_{\leq T}, Y_{\leq T})\right)
\end{align}
$$
Then:
$$
\begin{align}
\log\left(\int \prod_{t=1}^T P(Y_t|Y_{<t}, X_{\leq t}, Z_{\leq t}) P(Z_t|Z_{<t}, X_{\leq t})dZ_{\leq T}\right) & \geq \mathbb{E}_{q_{\phi}(Z_{\leq T}|X_{\leq T}, Y_{\leq T})}\left(\sum_{t=1}^T\log P(Y_t|Y_{<t}, X_{\leq t}, Z_{\leq t})\right)-KL\left(q_{\phi}(Z_{\leq T}|X_{\leq T}, Y_{\leq T})\,\Big\|\,\prod_{t=1}^T P(Z_t|Z_{<t}, X_{\leq t})\right)
\end{align}
$$
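A per-time-step training objective implementing this bound could look as follows (a single-sample Monte-Carlo estimate; the unit-variance Gaussian measurement decoder and the per-pixel factorization of the KL term are assumptions made for the sketch):
```python
from torch.distributions import Categorical, Normal, kl_divergence

def elbo_step(y_t, y_pred_mean, post_logits, prior_logits):
    """One-step contribution to the lower bound (to be summed over t and maximized).

    y_t:          (B, C')       true measurement
    y_pred_mean:  (B, C')       decoder mean for P(Y_t | Y_<t, X_<=t, Z_<=t)
    post_logits:  (B, K, H, W)  logits of q_phi(Z_t | ...)
    prior_logits: (B, K, H, W)  logits of P(Z_t | Z_<t, X_<=t)
    """
    # Reconstruction term: log P(Y_t | Y_<t, X_<=t, Z_<=t), unit-variance Gaussian
    rec = Normal(y_pred_mean, 1.0).log_prob(y_t).sum(dim=-1)

    # KL between per-pixel categorical posterior and prior, summed over pixels
    post = Categorical(logits=post_logits.permute(0, 2, 3, 1))
    prior = Categorical(logits=prior_logits.permute(0, 2, 3, 1))
    kl = kl_divergence(post, prior).sum(dim=(1, 2))

    return rec - kl  # negate and average over the batch to obtain a loss
```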
Here the posterior distribution $q_{\phi}(Z_{\leq T}|X_{\leq T}, Y_{\leq T})$, from which the latent segmentations $Z_{\leq T}$ are sampled, follows a categorical distribution. Luckily, we can use the Gumbel-Softmax reparametrization for categorical variables \[4\], which lets us backpropagate through the sampling step. Alternatively, if we want the segmentations to carry class probabilities instead of one-hot vectors, we can use Dirichlet distributions (proposed by Yoshua).
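For the sampling step, PyTorch ships a Gumbel-Softmax implementation (`torch.nn.functional.gumbel_softmax`); a minimal usage sketch with illustrative shapes:
```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 193, 501, requires_grad=True)  # (B, K, H, W) posterior logits

# Relaxed (soft) sample during training, or hard one-hot with a straight-through
# gradient if discrete masks are needed downstream.
z_soft = F.gumbel_softmax(logits, tau=1.0, dim=1)
z_hard = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=1)

z_hard.sum().backward()            # gradients flow back to the logits
print(z_hard.argmax(dim=1).shape)  # per-pixel segmentation labels: torch.Size([2, 193, 501])
```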
# References
\[1\] Artidoro Pagnoni, Kevin Liu, Shangyan Li: "Conditional Variational Autoencoder for Neural Machine Translation", 2018; [arXiv:1812.04405](http://arxiv.org/abs/1812.04405).
\[2\] Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron Courville, Yoshua Bengio: "A Recurrent Latent Variable Model for Sequential Data", 2015; [arXiv:1506.02216](http://arxiv.org/abs/1506.02216).
\[3\] Kihyuk Sohn, Honglak Lee, Xinchen Yan: "Learning Structured Output Representation using Deep Conditional Generative Models", 2015; [NeurIPS proceedings](https://proceedings.neurips.cc/paper/2015/file/8d55a249e6baa5c06772297520da2051-Paper.pdf).
\[4\] Eric Jang, Shixiang Gu, Ben Poole: "Categorical Reparameterization with Gumbel-Softmax", 2016; [arXiv:1611.01144](http://arxiv.org/abs/1611.01144).