# Hackathon IDRIS 2022: Deep OI
## 1. The 4DVarNet algorithm
Let $\mathbf{y}(\Omega)=\lbrace \mathbf{y}_k(\Omega_k) \rbrace$ denote the partial and potentially noisy observational dataset on subdomain $\Omega=\lbrace \Omega_k \rbrace \subset \mathcal{D}$, where $\overline{\Omega}$ denotes the gappy part of the field and index $k$ refers to time $t_k$. Using a data assimilation state-space formulation, we aim at estimating the hidden state $\mathbf{x}=\lbrace \mathbf{x}_k(\Omega_k) \rbrace$ from the observations $\mathbf{y}$.
The core version of the code may be found here: [4DVarNet github](https://github.com/CIA-Oceanix/4dvarnet-core)
### 1.1 The variational model
Considering a variational data assimilation scheme, the state analysis $\mathbf{x}^\star$ is obtained by solving the minimization problem:
\begin{align*}
\mathbf{x}^\star= \underset{\mathbf{x}}{\arg\min} \ \mathcal{J}(\mathbf{x})
\end{align*}
where the variational cost function $\mathcal{J}(\mathbf{x})=\mathcal{J}_{\Phi}(\mathbf{x},\mathbf{y},\Omega)$ is generally the sum of an observation term and a regularization term involving an operator $\Phi$ which is typically a dynamical prior:
\begin{align*}
\mathcal{J}_{\Phi}(\mathbf{x},\mathbf{y},\Omega) & = \mathcal{J}^o(\mathbf{x},\mathbf{y},\Omega) + \mathcal{J}_{\Phi}^b(\mathbf{x})\\
& = \lambda_1 || \mathbf{y} - \mathcal{H}(\mathbf{x}) ||^2_\Omega + \lambda_2 || \mathbf{x} - \Phi(\mathbf{x}) ||^2
\end{align*}
with $\mathcal{H}$ the observation operator and $\lambda_{1,2}$ predefined or learnable scalar weights. This formulation of the functional $\cal{J}_{\Phi}(\mathbf{x},\mathbf{y},\Omega)$ directly relates to strong-constraint 4D-Var.
For inverse problems with time-related processes, the minimization of the functional $\mathcal{J}_\Phi$ usually involves iterative gradient-based algorithms, and in particular requires the adjoint method in classic model-based variational data assimilation schemes, where operator $\Phi$ identifies with a deterministic model $\mathbf{x}_{k+1}=\mathcal{M}(\mathbf{x}_{k})$:
\begin{align*}
\mathbf{x}^{(i+1)} = \mathbf{x}^{(i)} - \alpha \nabla_{\mathbf{x}}\mathcal{J}_{\Phi}(\mathbf{x}^{(i)} ,\mathbf{y},\Omega)
\end{align*}
In our case, we are interested in a purely data-driven operator $\Phi$: we consider NN-based Gibbs-Energy (GENN) representations, a way of embedding Markovian priors in CNNs which has proven efficient on SSH altimetric datasets. This makes it possible to use deep learning automatic differentiation tools: the computation of the gradient operator $\nabla_{\mathbf{x}}\mathcal{J}_{\Phi}$, given the architecture of operator $\Phi$, can be seen as a composition of operators involving tensors, convolutions and activation functions.
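As an illustration, here is a minimal PyTorch sketch of this autodiff-based gradient computation. The toy prior `phi`, the tensors `x0`, `y` and the mask encoding $\Omega$ are hypothetical stand-ins, and $\mathcal{H}$ is taken as the identity restricted to the observed domain; the actual operators live in the 4dvarnet-core repo.

```python
import torch
import torch.nn as nn

# Toy stand-ins (hypothetical): a small CNN for the prior Phi,
# an initial state x0, observations y and a mask encoding Omega.
phi = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
x0 = torch.zeros(1, 1, 64, 64)
y = torch.randn(1, 1, 64, 64)
mask = (torch.rand(1, 1, 64, 64) > 0.9).float()

def var_cost(x, lambda1=1.0, lambda2=1.0):
    # J(x) = lambda1 * ||y - H(x)||^2_Omega + lambda2 * ||x - Phi(x)||^2,
    # with H the identity restricted to the observed domain Omega.
    obs_term = ((y - x) * mask).pow(2).sum()
    prior_term = (x - phi(x)).pow(2).sum()
    return lambda1 * obs_term + lambda2 * prior_term

# Gradient of J w.r.t. the state via automatic differentiation;
# create_graph=True keeps the graph so the solver itself stays trainable.
x = x0.clone().requires_grad_(True)
grad = torch.autograd.grad(var_cost(x), x, create_graph=True)[0]
```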
### 1.2 Trainable solver architecture
The proposed end-to-end architecture consists of embedding an iterative gradient-based solver derived from the considered variational representation. As inputs, we consider an observation $\mathbf{y}$, the associated observation domain $\Omega$ and some initialization $\mathbf{x}^{(0)}$. Let us denote by $\Gamma$ this iterative update operator. Following meta-learning schemes, a residual LSTM-based representation of operator $\Gamma$ is considered here, where the $i^{th}$ iterative update of the solver is given by:
\begin{align*}
\left \{\begin{array}{ccl}
g^{(i+1)}& = & \mathrm{LSTM} \left[ \alpha \cdot \nabla_{\mathbf{x}}\mathcal{J}_{\Phi}(\mathbf{x}^{(i)} ,\mathbf{y},\Omega), h^{(i)} , c^{(i)} \right ] \\~\\
\mathbf{x}^{(i+1)}& = & \mathbf{x}^{(i)} - {\cal{T}} \left( g^{(i+1)} \right )
\end{array} \right.
\end{align*}
where $g^{(i+1)}$ is the LSTM output given the input gradient $\nabla_{\mathbf{x}}\mathcal{J}_{\Phi}(\mathbf{x}^{(i)} ,\mathbf{y},\Omega)$, $h^{(i)}$ and $c^{(i)}$ denote the internal states of the LSTM, $\alpha$ is a normalization scalar and ${\cal{T}}$ a linear or convolutional mapping.
Note that a CNN architecture could also be used instead of the LSTM representation of $\Gamma$, and that when replacing both the LSTM cell by the identity operator and the cost function $\mathcal{J}_{\Phi}(\mathbf{x},\mathbf{y},\Omega)$ by its single regularization term $\mathcal{J}_{\Phi}^b(\mathbf{x})$, the gradient-based solver simply leads to a parameter-free fixed-point version of the algorithm.
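A minimal sketch of one such update follows, using a plain `nn.LSTMCell` on a flattened state for readability (the actual code relies on a convolutional LSTM); dimensions and names are illustrative only.

```python
import torch
import torch.nn as nn

class SolverStep(nn.Module):
    """One iteration: x^(i+1) = x^(i) - T(LSTM[alpha * grad, h, c])."""
    def __init__(self, dim, hidden=128, alpha=0.1):
        super().__init__()
        self.alpha = alpha
        self.lstm = nn.LSTMCell(dim, hidden)  # the real code uses a conv-LSTM
        self.T = nn.Linear(hidden, dim)       # mapping back to state space

    def forward(self, x, grad, hc=None):
        h, c = self.lstm(self.alpha * grad.flatten(1), hc)
        x_new = x - self.T(h).view_as(x)      # residual update of the state
        return x_new, (h, c)

# usage: iterate a fixed number of solver steps
step = SolverStep(dim=64 * 64)
x = torch.zeros(4, 64 * 64)
hc = None
for _ in range(10):
    grad = torch.randn_like(x)  # stand-in for the variational cost gradient
    x, hc = step(x, grad, hc)
```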
### 1.3 End-to-end joint learning scheme
Overall, let us denote by $\Psi_{\Phi,\Gamma}(\mathbf{x}^{(0)},\mathbf{y},\Omega)$ the output of the end-to-end learning scheme (see Figure below), given architectures for both NN-based operators $\Phi$ and $\Gamma$, the initialization $\mathbf{x}^{(0)}$ of state $\mathbf{x}$ and the observations $\mathbf{y}$ on domain $\Omega$.

Then, the joint learning of operators $\lbrace\Phi,\Gamma\rbrace$ is stated as the minimization of a reconstruction cost:
\begin{align*}
\arg \min_{\Phi,\Gamma} \mathcal{L}(\mathbf{x},\mathbf{x}^\star) \mbox{ s.t. }
\mathbf{x}^\star = \Psi_{\Phi,\Gamma} (\mathbf{x}^{(0)},\mathbf{y},\Omega)
\end{align*}
In the case of supervised learning, where targets are gap-free,
$\mathcal{L}(\mathbf{x},\mathbf{x}^\star)=||\mathbf{x}-\mathbf{x}^\star||^2+||\nabla\mathbf{x}-\nabla\mathbf{x}^\star||^2$, i.e. the L2 norm of the difference between state $\mathbf{x}$ and reconstruction $\mathbf{x}^\star$, with an additional term on the difference between their gradients.
In the case of unsupervised learning, given the observations $\mathbf{y}$ on domain $\Omega$ and hidden state $\mathbf{x}$, the 4DVar cost function may be used: $\mathcal{L}(\mathbf{x},\mathbf{x}^\star) = \lambda_1 || \mathbf{y} - \mathcal{H}(\mathbf{x}) ||^2_\Omega + \lambda_2 || \mathbf{x} - \Phi(\mathbf{x}) ||^2$, with weights $\lambda_1$ and $\lambda_2$ to be adapted according to the reliability of the observations.
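A hedged sketch of the supervised loss follows, with the gradient term approximated by finite differences; the exact implementation in the repo may differ.

```python
import torch
import torch.nn.functional as F

def supervised_loss(x, x_star, w_grad=1.0):
    """||x - x*||^2 plus an L2 term on finite-difference spatial gradients."""
    def spatial_grad(u):
        # simple forward differences along the two spatial axes
        return u[..., 1:, :] - u[..., :-1, :], u[..., :, 1:] - u[..., :, :-1]
    gy, gx = spatial_grad(x)
    gys, gxs = spatial_grad(x_star)
    return (F.mse_loss(x_star, x)
            + w_grad * (F.mse_loss(gys, gy) + F.mse_loss(gxs, gx)))
```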
## 2. The Dataset
### 2.1 SSH altimetric dataset
The 4DVarNet algorithm has been successfully tested on small datasets, typically sequences of spatio-temporal images of size 7x200x200 (Nt x Ny x Nx).


**OI (left) and 4DVarNet (right)** *(figure)*
#### 2.1.1 Location
In this Hackathon, the main objective is to reach or even beat the state-of-the-art OI performance.
We use a hydra-based setup, so the three file paths related to the OI, the observations and the ground truth can be found in:
```
/gpfswork/rech/yrf/uba22to/4dvarnet-core/hydra_config/file_paths/dc_osse.yaml
```
They are stored using [NetCDF format](https://www.unidata.ucar.edu/software/netcdf/), which can be handled in Python using for instance the [xarray](https://xarray.pydata.org/en/stable/index.html) package.
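For instance, one might inspect the configured paths and open the ground-truth file as follows; the keys inside the yaml are assumptions, so check the file for the actual names.

```python
import xarray as xr
from omegaconf import OmegaConf

# Load the hydra file-path config (the key names inside are assumptions).
paths = OmegaConf.load(
    "/gpfswork/rech/yrf/uba22to/4dvarnet-core/hydra_config/file_paths/dc_osse.yaml"
)
print(OmegaConf.to_yaml(paths))

# Open the ground-truth NetCDF lazily with xarray.
gt = xr.open_dataset("NATL60-CJM165_NATL_ssh_y2013.1y.nc")
print(gt)                                    # time=365, lat=761, lon=1721
ssh_day1 = gt["ssh"].sel(time="2012-10-01")  # one daily 761x1721 field
```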
#### 2.1.2 Description
All three datasets (reference, data, optimal interpolation) have three dimensions:
- **time** (365)
- **lat** (761)
- **lon** (1721)
All variables are stored as 761x1721 regular 2D grids for each of the 365 days, with the following coordinates:
- **time** 365 days: _'2012-10-01' '2012-10-02' ... '2013-09-30'_
- **lat** 761 x 1/20° : _27.0 27.05 27.1 27.15 27.2 ... 64.85 64.9 64.95 65.0_
- **lon** 1721 x 1/20° : _-79.0 -78.95 -78.9 -78.85 ... 6.9 6.95 7.0_
##### 2.1.2.1 Ground Truth
- The SSH Ground Truth **NATL60-CJM165_NATL_ssh_y2013.1y.nc**:
    - **ssh** (Sea Surface Height): one-year-long daily dataset provided by the [NATL60](https://meom-group.github.io/swot-natl60/science.html) state-of-the-art oceanic simulation
- The SST complementary Ground Truth **NATL60-CJM165_NATL_sst_y2013.1y.nc**:
    - **sst** (Sea Surface Temperature): SST may be used for complementary tests to improve the SSH spatio-temporal interpolation
##### 2.1.2.2 Pseudo-observations
- The pseudo-observations dataset **dataset_nadir_0d_swot.nc** is generated by sampling the SSH Ground Truth along realistic satellite trajectories:
- **mask** : mask (1 -> ocean, 0 -> land)
- **lag** : time deviation (in hours) from the selected day
- **flag** : satellite type (0 -> NADIR, 1 -> SWOT)
- **ssh_obs**: data with additional realistic noise
- **ssh_mod**: data without noise (=model)
##### 2.1.2.3 Optimal interpolation (OI)
- The state-of-the-art optimal interpolation (OI) dataset **ssh_NATL60_swot_4nadir.nc**, based on the previous pseudo-observations. This is the baseline the 4DVarNet algorithm aims to improve:
- **ssh_obs**: OI using ssh_obs in the pseudo-observations dataset
- **ssh_mod**: OI using ssh_mod in the pseudo-observations dataset
### 2.2 SPDE-based GP dataset
For this new dataset, we generate spatio-temporal SPDE-based GP simulations with local anisotropies.
In this data challenge, the parameter $\alpha$ of the SPDE:
$\frac{\partial{\mathbf{x}}}{\partial{t}}+(\kappa^2(\mathbf{s},t)-\nabla \mathbf{H}(\mathbf{s})\nabla)^{\alpha/2} \mathbf{x}(\mathbf{s},t)=\tau \mathbf{z}(\mathbf{s},t)$
is set to 4, so that the resulting field is smooth enough and not too noisy. This leads to spatio-temporal fields more compliant with realistic geophysical datasets. Users may play with the simulation setup to generate alternate datasets for their own use.
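As a rough illustration of the simulation principle, here is a minimal explicit-Euler sketch with $\alpha=4$, constant $\kappa$, $\mathbf{H}$ reduced to the identity and periodic boundaries (the actual dataset uses space-time-varying anisotropies); for $\alpha=4$, the fractional operator amounts to applying $(\kappa^2-\Delta)$ twice.

```python
import numpy as np

ny, nx, nt = 100, 100, 50
dt, kappa, tau = 0.01, 0.5, 1.0
rng = np.random.default_rng(0)

def laplacian(u):
    # 5-point Laplacian with periodic boundaries
    return (np.roll(u, 1, 0) + np.roll(u, -1, 0)
            + np.roll(u, 1, 1) + np.roll(u, -1, 1) - 4.0 * u)

def L_op(u):
    # (kappa^2 - Laplacian) u; for alpha = 4, apply it twice
    return kappa**2 * u - laplacian(u)

x = np.zeros((ny, nx))
fields = []
for _ in range(nt):
    z = rng.standard_normal((ny, nx))
    # explicit Euler step of dx = -L(L(x)) dt + tau dW
    x = x + dt * (-L_op(L_op(x))) + tau * np.sqrt(dt) * z
    fields.append(x.copy())
```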
Pseudo-observations are generated from these simulations with moving masks across space and time.
Lastly, the SPDE-based optimal interpolation is also provided, based on a data assimilation window of 5 time steps centered around the leading time to estimate.
**Simulation (Ground Truth, left) and pseudo-observation (right)** *(figure)*
Again, thanks to the hydra-based setup, the dataset information can be found in:
```
/gpfswork/rech/yrf/uba22to/4dvarnet-core/hydra_config/file_paths/dc_gp.yaml
```
## 3. Notes
* Note that in the first version of the algorithm, 4DVarNet was applied to the anomaly $\mathbf{x}-\overline{\mathbf{x}}$ between the raw SSH and the OI (denoted here as $\overline{\mathbf{x}}$ and seen as a large-scale component of the SSH). So far, this is the solution giving the best results, **with spatio-temporal patches of size 200 (lon) x 200 (lat) x 7 (days)**. The OI was also used as an additional input channel in the 4DVarNet algorithm and can thus be seen as an extra covariate that helps to localize the areas with large anomalies.
* In the new version, which we would like to test during the 2022 IDRIS GPU Hackathon, we will work on the raw variables, meaning that temporal correlations are longer, and we intend to use the following patch sizes (see the slicing sketch below):
    * **spatio-temporal patches of size 240 (lon) x 240 (lat) x 29 (days)** for the SSH dataset
    * **spatio-temporal patches of size 100 (x) x 100 (y) x ?? (t)** for the GP dataset
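A hypothetical slicing sketch for such patches, using simple strided indexing on a (time, lat, lon) cube; the function name and stride values are illustrative only.

```python
import numpy as np

def extract_patches(cube, pt=29, py=240, px=240, st=7, sy=200, sx=200):
    """Slice (pt, py, px) patches out of a (time, lat, lon) cube
    with the given strides. Strides here are illustrative."""
    nt, ny, nx = cube.shape
    patches = [cube[t:t + pt, i:i + py, j:j + px]
               for t in range(0, nt - pt + 1, st)
               for i in range(0, ny - py + 1, sy)
               for j in range(0, nx - px + 1, sx)]
    return np.stack(patches)

# toy demo on a reduced cube (the real SSH cube is 365 x 761 x 1721)
patches = extract_patches(np.zeros((60, 480, 480), dtype=np.float32))
print(patches.shape)  # (n_patches, 29, 240, 240)
```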
## 4. Code available on GitHub
Here is the GitHub repo with a simplified version of our code: [4DVarNet](https://github.com/CIA-Oceanix/4dvarnet-core), together with some explanations regarding its architecture: [architecture of the code](https://github.com/CIA-Oceanix/4dvarnet-core/blob/main/doc/code_archi.md).
Key components of the code:
* 4DVarNet PyTorch Lightning module
* Multi-GPU/multi-node distribution using PyTorch Lightning (see the sketch below)
* Possible on-the-fly batch generation from raw files
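A hedged sketch of the corresponding multi-GPU/multi-node `Trainer` setup; `LitModel` and `dm` are placeholders for the 4DVarNet Lightning module and data module defined in the repo.

```python
import pytorch_lightning as pl

# `LitModel` and `dm` are placeholders for the 4DVarNet Lightning
# module and data module defined in the repo.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,        # GPUs per node
    num_nodes=2,      # number of nodes on the cluster
    strategy="ddp",   # distributed data parallel
    precision=16,     # optional mixed precision (see xp1-1 below)
)
# trainer.fit(LitModel(), datamodule=dm)
```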
In the branch **mbeaucha**, a folder **oi** contains pieces of code related to the Hackathon experiments.
In particular, it provides some notebooks to advance quickly on these tasks and help reproduce the preliminary results.
## 5. Workplan for the Hackathon
Notebooks are provided as starting points for both SSH-OSSE mapping and GP mapping datasets.
For the SSH-OSSE mapping, the notebook can be found in:
```
/gpfswork/rech/yrf/uba22to/4dvarnet-core/oi/eval_notebooks/eval_4dvarnet_OI_OSSE.ipynb
```
For the GP mapping (xp2-1 and xp2-2), the notebooks can be found in:
```
/gpfswork/rech/yrf/uba22to/4dvarnet-core/oi/eval_notebooks/eval_4dvarnet_OI_GP.ipynb
/gpfswork/rech/yrf/uba22to/4dvarnet-core/oi/eval_notebooks/eval_SPDE_LSTM_OI_GP.ipynb
```
In these notebooks, the 4DVarNet code is launched and some metrics are computed.
### 5.1 Experiment 1
* xp1-1: Reproduce with the multi-GPU implementation the metrics obtained with the mono-GPU implementation of the code:
    * scale learning rates
    * mini-batch normalization
    * increase the batch size (Quentin's test: batch size = 5)
    * use mixed precision (Quentin's test: the loss decreases)
    * test other optimizers
* xp1-2:
    * take oi/models.py from the mbeaucha 4dvarnet-core branch and plug it into the main branch
    * update the resize_factor option (see hydra_main and oi/dataloader_osse in the mbeaucha branch)
    * learn the GT with single-scale estimation (reproduce the anomaly scores)
    * learn the GT with two-scale estimation (reproduce the anomaly scores)
* Leaderboard:
| Method | µ(RMSE) | σ(RMSE) | λx (dx=1) | λt (dt=1) | Notes | Reference |
|:-----------|------------------------:|---------------------:|-------------------------:|-----------------------:|:--------------------------|:-----------------|
| baseline DUACS OI :trophy: | **0.91** | **0.02** | **1.43** | **11.29** | SPDE-based precision matrix | eval_notebooks/eval_4dvarnet_OI_OSSE.ipynb |
| | | | | | | |
| 4DVarNet-anomaly (NN 2 scale) | | | | | 4DVarNet mapping | eval_notebooks/eval_4dvarnet_OI_OSSE.ipynb |
| 4DVarNet-OI (NN 1 scale)| | | | | 4DVarNet mapping | eval_notebooks/eval_4dvarnet_OI_OSSE.ipynb |
| 4DVarNet-OI (NN 2 scale) | | | | | 4DVarNet mapping | eval_notebooks/eval_4dvarnet_OI_OSSE.ipynb |
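The µ(RMSE)/σ(RMSE) entries above can be reproduced with a normalized-RMSE score; the convention below (1 − RMSE/RMS of the truth, averaged over days) is an assumption borrowed from the SSH mapping data challenges, so check the evaluation notebook for the exact definition.

```python
import numpy as np

def nrmse_score(gt, rec):
    """Normalized RMSE score: 1 - RMSE / RMS(ground truth).
    Higher is better; this convention is an assumption, the exact
    metric is defined in the evaluation notebook."""
    rmse = np.sqrt(np.nanmean((gt - rec) ** 2))
    rms = np.sqrt(np.nanmean(gt ** 2))
    return 1.0 - rmse / rms

# mu(RMSE), sigma(RMSE): mean and std of the daily scores
# scores = np.array([nrmse_score(gt[t], rec[t]) for t in range(gt.shape[0])])
# mu, sigma = scores.mean(), scores.std()
```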
### 5.2 Experiment 2
* xp2-1: Prior NN + Solver LSTM (classic 4DVarNet)
* xp2-2: Prior SPDE (H given) + Solver LSTM
* xp2-3: Prior SPDE (H to estimate) + Solver LSTM
* xp2-*:
    * Run a profiler
    * search for improvements
    * prior identification
        * MLE (see the Cholesky-based sketch after the leaderboard):
            * $\mathcal{L}(\mathbf{x}|\hat{\mathbf{Q}})=\frac{1}{2}\big[-p\,\mathrm{ln}(2\pi) + \mathrm{ln}(| \hat{\mathbf{Q}}|) - \mathbf{x}^{T}\hat{\mathbf{Q}}\mathbf{x} \big]$
            * where $p=N_x \times N_y$, and $\mathrm{ln}(| \hat{\mathbf{Q}}|)=\mathrm{ln}(| \hat{\mathbf{L}}\hat{\mathbf{L}}^{T}|)=\mathrm{ln}\Big(\big(\prod\mathrm{diag}(\hat{\mathbf{L}})\big)^{2}\Big)=2\sum\mathrm{ln}\big(\mathrm{diag}(\hat{\mathbf{L}})\big)$, with $\hat{\mathbf{L}}$ the Cholesky factor of $\hat{\mathbf{Q}}$
* Leaderboard:
| Method | µ(RMSE) | σ(RMSE) | λx (dx=1) | λt (dt=1) | Notes | Reference |
|:-----------|------------------------:|---------------------:|-------------------------:|-----------------------:|:--------------------------|:-----------------|
| baseline SPDE OI :trophy: | **0.31** | **0.04** |**26.61** | **17.21** | SPDE-based precision matrix | eval_notebooks/eval_4dvarnet_OI_GP.ipynb |
| | | | | | | |
| 4DVarNet | | | | | 4DVarNet mapping | eval_notebooks/eval_4dvarnet_OI_GP.ipynb |
| Prior SPDE (H given) + Solver LSTM | | | | | Prior SPDE (H given) + Solver LSTM mapping | eval_notebooks/eval_SPDE_LSTM_OI_GP.ipynb |
| Prior SPDE (H) + Solver LSTM | | | | | Prior SPDE (H) + Solver LSTM mapping | eval_notebooks/eval_SPDE_LSTM_OI_GP.ipynb |
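A minimal dense sketch of the Gaussian log-likelihood above via a Cholesky factorization; the actual precision matrices are sparse, so a sparse Cholesky would be used in practice.

```python
import numpy as np

def gaussian_loglik(x, Q):
    """log p(x | Q) = 0.5 * (-p ln(2*pi) + ln|Q| - x^T Q x) for
    x ~ N(0, Q^{-1}), with ln|Q| = 2 * sum(ln diag(L)) and Q = L L^T."""
    p = x.size
    L = np.linalg.cholesky(Q)
    logdet_Q = 2.0 * np.sum(np.log(np.diag(L)))
    return 0.5 * (-p * np.log(2.0 * np.pi) + logdet_Q - x @ Q @ x)

# toy check on a small SPD precision matrix
rng = np.random.default_rng(0)
A = rng.standard_normal((16, 16))
Q = A @ A.T + 16 * np.eye(16)   # symmetric positive definite
x = rng.standard_normal(16)
print(gaussian_loglik(x, Q))
```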