## A differentiable JAX-powered Shallow Water Model
---
## Motivation
A flexible, differentiable code for dynamical systems
- For testing ideas on parameter inference, closures, model errors and state estimation.
- Key requirement: embedding time-stepping in the learning loop.
- Concept: a common interface for different use cases in climate science.
----
## Why JAX?
- NumPy syntax
- composable transformations: `vmap`/`pmap`, `jit`, `grad`/`jacobian`/`hessian`
- runs on CPU, GPU, TPU
- computational performance for CFD problems
- a growing community in Sci-ML
- it's fun!
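As a quick illustration of the transformations listed above, here is a minimal sketch on a toy discrete-energy function (`energy` is purely illustrative, not part of the model):

```python
import jax
import jax.numpy as jnp

def energy(u, dx):
    """Discrete kinetic energy 0.5 * sum(u**2) * dx of a 1D field u."""
    return 0.5 * jnp.sum(u**2) * dx

# grad: derivative with respect to the first argument (the field u)
denergy_du = jax.grad(energy)
# jit: compile the gradient with XLA (same code on CPU, GPU or TPU)
fast_grad = jax.jit(denergy_du)
# hessian: full matrix of second derivatives (here dx times the identity)
hess = jax.hessian(energy)
# vmap: energy of a batch of fields without a Python loop (dx is not batched)
batched_energy = jax.vmap(energy, in_axes=(0, None))

x = jnp.linspace(0.0, 1.0, 5)
u = jnp.sin(2 * jnp.pi * x)
print(fast_grad(u, 0.25))  # equals u * dx
print(hess(u, 0.25))
print(batched_energy(jnp.stack([u, 2 * u]), 0.25))
```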
---
## Use Cases (1/2)
- ML-based eddy closures (boundaries, end-to-end, online/offline)
- Learning and representing model errors (from DA or through end-to-end learning)
- Testing ideas on the inversion of DA increments
----
## Use Cases (2/2)
- Simple model-based inversion of satellite observations (altimetry)
- Data-driven eddy-wave separation techniques
- Parameterizations with free access to gradients
---
## Our Vision
---

----
#### Requirements (1/2)
- Common API for several dynamical systems
- Flexible interface for easy experiment and model setup
- Deploy parameterizations in a few lines of code
- Easy access to gradients
- Implicit differentiation for long integrations
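One possible reading of the implicit-differentiation requirement: when a time step is implicit (backward Euler), its result is the root of a nonlinear equation, and `jax.lax.custom_root` differentiates that root through the implicit function theorem instead of backpropagating through every solver iteration, so the backward pass does not grow with the number of iterations. A minimal sketch on a toy pointwise ODE du/dt = -mu u³ (the equation, `implicit_step` and the hand-written Newton derivative are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

def implicit_step(u_old, mu, dt=0.1, n_iter=20):
    """Backward-Euler step for du/dt = -mu * u**3 (toy, pointwise).

    The new state is the root of F(u) = u - u_old + dt * mu * u**3.
    """
    residual = lambda u: u - u_old + dt * mu * u**3

    def newton_solve(f, u0):
        # Newton iterations with the derivative written out by hand:
        # dF/du = 1 + 3 dt mu u^2 (F acts pointwise, so the Jacobian is diagonal).
        def body(u, _):
            return u - f(u) / (1.0 + 3.0 * dt * mu * u**2), None
        u, _ = jax.lax.scan(body, u0, None, length=n_iter)
        return u

    # The linearized system is diagonal too: evaluating the linear map g on a
    # vector of ones gives its diagonal, so it is solved by elementwise division.
    tangent_solve = lambda g, y: y / g(jnp.ones_like(y))
    return jax.lax.custom_root(residual, u_old, newton_solve, tangent_solve)

u0 = jnp.array([1.0, 0.5, 2.0])
u1 = implicit_step(u0, 1.0)
# gradient w.r.t. mu obtained from the implicit function theorem
dmu = jax.grad(lambda m: implicit_step(u0, m).sum())(1.0)
```

For this toy case the implicit function theorem gives du/dmu = -dt u³ / (1 + 3 dt mu u²), which is what `jax.grad` returns here.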
----
#### Requirements (2/2)
- Handle irregular and interior boundaries
- Allow for different numerical schemes (FD, FV)
- Good computational performance for CFD
- Ability to deploy complex observation operators
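Irregular and interior boundaries could, for instance, be handled with a wet/dry mask that closes every cell face touching land; written in flux (finite-volume) form, this conserves the tracer exactly. A minimal sketch for 1D upwind advection with a positive velocity (the mask convention, constants and function names are illustrative assumptions, not a chosen design):

```python
import jax
import jax.numpy as jnp

C, DX, DT = 1.0, 0.1, 0.05  # velocity, cell size, time step (CFL number 0.5)

def fv_advection_tendency(q, wet, c, dx):
    """Upwind finite-volume tendency for dq/dt + d(c q)/dx = 0 with c > 0.

    q   : cell averages, shape (n,)
    wet : 1.0 for ocean cells, 0.0 for land (interior boundary) cells
    Fluxes live on the n + 1 cell faces; faces on the domain walls or
    adjacent to a land cell carry zero flux.
    """
    # upwind flux on interior face i+1/2 uses the left cell (c > 0)
    interior_flux = c * q[:-1]
    # a face is open only if both neighbouring cells are wet
    open_face = wet[:-1] * wet[1:]
    flux = jnp.concatenate([jnp.zeros(1), interior_flux * open_face, jnp.zeros(1)])
    # flux divergence, masked so land cells never change
    return -(flux[1:] - flux[:-1]) / dx * wet

@jax.jit
def step(q, wet):
    """One forward-Euler step."""
    return q + DT * fv_advection_tendency(q, wet, C, DX)

n = 20
wet = jnp.ones(n).at[10:12].set(0.0)  # an interior "island" in cells 10-11
q0 = jnp.exp(-((jnp.arange(n) - 5.0) ** 2) / 4.0) * wet
q = q0
for _ in range(50):
    q = step(q, wet)
```

Because every closed face carries zero flux, the fluxes telescope and the total tracer amount stays constant while mass piles up against the island.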
----
#### Core structure

----
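#### Pseudocode
```python=
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class Solver:
    state: Optional[Variables] = None
    params: Dict[str, Any] = field(default_factory=dict)
    time: float = 0.0
    # interface implemented by each concrete system (type of `scheme` still open)
    def rhs(self, state: Variables, params: Dict[str, Any], scheme: Any) -> Variables: ...
    def set_bc(self, state: Variables, bc: BCond) -> Variables: ...
    def forward(self, state: Variables, delta: float, timestepper: Stepper) -> Variables: ...

@dataclass
class Burgers1D(Solver):
    params: Dict[str, Any] = field(default_factory=lambda: {"mu": None})  # 'mu': float
```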
----
##### Components
```python=
from __future__ import annotations
from typing import Callable

class Grid: ...

class Variable:
    grid: Grid

class Stepper:
    def advance(self, delta: float, rhs: Callable, state: Variables): ...

class BCond:
    def set(self, state: Variables) -> Variables: ...
```
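A minimal sketch of two concrete steppers that follow the `advance(delta, rhs, state)` signature of the `Stepper` component, assuming a state is a JAX array or a pytree of arrays; `ForwardEuler`, `RK4` and `integrate` are hypothetical names. Rolling the stepper out with `jax.lax.scan` keeps the whole integration traceable by `jit` and `grad`.

```python
from typing import Callable

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

class ForwardEuler:
    """Hypothetical Stepper: first-order explicit Euler."""
    def advance(self, delta: float, rhs: Callable, state):
        return tree_map(lambda s, r: s + delta * r, state, rhs(state))

class RK4:
    """Hypothetical Stepper: classical fourth-order Runge-Kutta."""
    def advance(self, delta: float, rhs: Callable, state):
        # axpy(a, x, y) = y + a * x, applied leaf by leaf
        axpy = lambda a, x, y: tree_map(lambda xi, yi: yi + a * xi, x, y)
        k1 = rhs(state)
        k2 = rhs(axpy(delta / 2, k1, state))
        k3 = rhs(axpy(delta / 2, k2, state))
        k4 = rhs(axpy(delta, k3, state))
        incr = tree_map(lambda a, b, c, d: (a + 2 * b + 2 * c + d) / 6, k1, k2, k3, k4)
        return axpy(delta, incr, state)

def integrate(stepper, rhs, state, delta, n_steps):
    """Roll the stepper forward; returns the final state and the trajectory."""
    def body(s, _):
        s_next = stepper.advance(delta, rhs, s)
        return s_next, s_next
    return jax.lax.scan(body, state, None, length=n_steps)

# usage: exponential decay du/dt = -k u, differentiated w.r.t. the rate k
u0 = jnp.array([1.0, 2.0])
final, traj = integrate(RK4(), lambda u: -u, u0, 0.1, 10)
dk = jax.grad(lambda k: integrate(RK4(), lambda u: -k * u, u0, 0.1, 10)[0].sum())(1.0)
```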
----
###### Example
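```python=
import jax

hr_solver = Burgers1D(hr_grid, bcond)
hr_data = hr_solver.forward(stepper, dt, t_final)
lr_solver = Burgers1D(lr_grid, bcond)
# attach a learned term to the low-resolution model
lr_solver.append_rhs(lambda state: state + neural_net(state))

def online_mse(lr_solver):
    # online loss: computed on a full time integration of the LR model
    return mse(hr_data, lr_solver.forward(stepper, dt, t_final))

# requires the solver (with its network weights) to be a JAX pytree
grad = jax.grad(online_mse)(lr_solver)
...
```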
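A self-contained, runnable version of the same idea, with the solver API replaced by plain functions. Assumptions made for the sketch: periodic 1D viscous Burgers with central differences and forward Euler, a high-resolution run block-averaged onto the coarse grid as the target, and the neural network replaced by a single learned eddy-viscosity coefficient `nu_t` to keep it short. All names and parameter values are illustrative.

```python
import jax
import jax.numpy as jnp

L = 2 * jnp.pi            # periodic domain length
NU = 0.02                 # physical viscosity (illustrative)
DT, N_STEPS = 1e-3, 200   # time step and number of steps
N_HR, FACTOR = 256, 8     # high-resolution grid and coarsening factor
N_LR = N_HR // FACTOR

def burgers_rhs(u, dx, nu):
    """Viscous Burgers tendency -u u_x + nu u_xx, periodic central differences."""
    u_p, u_m = jnp.roll(u, -1), jnp.roll(u, 1)
    dudx = (u_p - u_m) / (2 * dx)
    d2udx2 = (u_p - 2 * u + u_m) / dx**2
    return -u * dudx + nu * d2udx2

def trajectory(u0, dx, nu, n_steps=N_STEPS):
    """Forward-Euler integration with lax.scan; returns every intermediate state."""
    def body(u, _):
        u_next = u + DT * burgers_rhs(u, dx, nu)
        return u_next, u_next
    _, traj = jax.lax.scan(body, u0, None, length=n_steps)
    return traj

def coarse_grain(u, factor):
    """Block-average the last axis by `factor` (HR grid -> LR grid)."""
    return u.reshape(*u.shape[:-1], -1, factor).mean(axis=-1)

x_hr = jnp.linspace(0.0, L, N_HR, endpoint=False)
u0_hr = jnp.sin(x_hr) + 0.5 * jnp.sin(3 * x_hr)
hr_data = coarse_grain(trajectory(u0_hr, L / N_HR, NU), FACTOR)  # (N_STEPS, N_LR)

def online_mse(nu_t):
    """Online loss: roll out the LR model with the closure, compare whole trajectories."""
    u0_lr = coarse_grain(u0_hr, FACTOR)
    lr_traj = trajectory(u0_lr, L / N_LR, NU + nu_t)
    return jnp.mean((lr_traj - hr_data) ** 2)

# the gradient flows back through all N_STEPS time steps of the LR model
loss, grad = jax.value_and_grad(online_mse)(0.0)
```

A training loop would then update `nu_t <- nu_t - lr * grad`; with an actual network, `nu_t` becomes the network's parameter pytree and the closure term replaces the extra viscosity.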
---
# Reconnaissance
----
## Existing Libraries
* Specific:
  * [jaxdf](https://github.com/ucl-bug/jaxdf) - differentiable numerical discretizations
  * [jax-fenics](https://github.com/IvanYashchuk/jax-fenics) - FEniCS PDE solvers usable from JAX
* High-level: [Dynamical System](https://github.com/googleinterns/invobs-data-assimilation/blob/master/dynamical_system.py) - inspiration for a user-facing API; a thin, easy-to-use wrapper built on the JAX-CFD library
----
### JAX-CFD
JAX for Computational Fluid Dynamics
----
* core finite-volume/finite-difference methods for CFD, written in JAX
* pseudospectral methods for CFD
* machine-learning-augmented models for CFD, written in JAX and Haiku
* data-processing utilities, written with Xarray and Pillow
----
#### Verdict
* ✔️ - low-level API (similar to ours)
* ✔️ - well thought-out
* ❌ - missing many features we need (general finite-volume schemes, boundary conditions)
* ❌ - Navier-Stokes only
---
## Next Steps
- Coordinate future collaboration (when2meet)
- Build from scratch for simple use cases
  - 1D FV with wall boundaries
- Build on our scratch code and JAX-CFD to address all stated requirements
- Potential collaboration with the JAX-CFD developers
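A possible starting point for the "1D FV with wall boundaries" item: a minimal sketch of the linearized 1D shallow-water equations in flux form. The choices here (linearized equations, staggered Arakawa C-grid, forward-backward time stepping, the constants `G` and `H`) are assumptions for illustration, not decisions recorded in these notes.

```python
from functools import partial

import jax
import jax.numpy as jnp

G, H = 9.81, 10.0  # gravity and resting depth (illustrative values)

def sw_step(h, u, dx, dt):
    """One forward-backward step of the linearized 1D shallow-water equations.

    h : surface elevation at the n cell centres
    u : velocity on the n + 1 cell faces (staggered C-grid)
    Walls: both boundary faces keep zero velocity, hence zero mass flux.
    """
    # momentum, du/dt = -g dh/dx, on interior faces only
    u_int = u[1:-1] - dt * G * (h[1:] - h[:-1]) / dx
    u_new = jnp.concatenate([jnp.zeros(1), u_int, jnp.zeros(1)])
    # continuity in flux form, dh/dt = -d(H u)/dx, with the updated velocity
    h_new = h - dt * H * (u_new[1:] - u_new[:-1]) / dx
    return h_new, u_new

@partial(jax.jit, static_argnames="n_steps")
def run(h0, dx, dt, n_steps):
    """Integrate from rest with lax.scan and return the final (h, u)."""
    u0 = jnp.zeros(h0.shape[0] + 1)
    def body(carry, _):
        h, u = carry
        return sw_step(h, u, dx, dt), None
    (h, u), _ = jax.lax.scan(body, (h0, u0), None, length=n_steps)
    return h, u

# usage: a Gaussian bump in a 1 km basin; CFL number sqrt(g H) dt / dx is about 0.5
n, dx, dt = 100, 10.0, 0.5
x = (jnp.arange(n) + 0.5) * dx
h0 = jnp.exp(-(((x - 500.0) / 50.0) ** 2))
h, u = run(h0, dx, dt, n_steps=1000)
```

Since the wall faces carry no flux, the divergence terms telescope and the total volume `h.sum() * dx` is conserved up to round-off, while the waves reflect off both walls.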