## A differentiable JAX-powered Shallow Water Model

---

## Motivation

A (flexible) differentiable code for dynamical systems:

- For testing ideas on parameter inference, closures, model errors and state estimation.
- Key requirement: embedding time-stepping in the learning.
- Concept: a common interface for different use cases in climate science.

----

## Why JAX?

- NumPy-like syntax (`jax.numpy`)
- `vmap`/`pmap`, `jit`, `grad`/`jacobian`/`hessian`
- Runs on CPU, GPU and TPU
- Good computational performance for CFD problems
- A growing community in Sci-ML
- It's fun!

---

## Use Cases (1/2)

- ML-based eddy closures (boundaries, end-to-end, online/offline)
- Learning and representing model errors, either from data assimilation (DA) or end-to-end
- Testing ideas on the inversion of DA increments

----

## Use Cases (2/2)

- Simple model-based inversion of satellite observations (altimetry)
- Data-driven eddy-wave separation techniques
- Parametrizations with free access to gradients

---

## Our Vision

---

![](https://i.imgur.com/PPWnyiE.jpg)

----

#### Requirements (1/2)

- Common API for several dynamical systems
- Flexible interface for easy experiment and model set-up
- Deploy parameterizations in a few lines of code
- Easy access to gradients
- Implicit differentiation for long integrations

----

#### Requirements (2/2)

- Handle irregular and interior boundaries
- Allow for different numerical schemes (FD, FV)
- Good computational performance for CFD
- Support for complex observation operators

----

#### Core structure

![](https://i.imgur.com/9HJLFEL.jpg)

----

#### PseudoCode

```python=
from typing import Any, Dict

class Solver:
    state: "Variables" = None
    params: Dict[str, Any] = None
    time: float = 0.0

    # `scheme`: type still open (e.g. an FD or FV discretization)
    def rhs(self, state: "Variables", params: Dict[str, Any], scheme: Any) -> "Variables": ...
    def set_bc(self, state: "Variables", bc: "BCond") -> "Variables": ...
    def forward(self, state: "Variables", delta: float, timestepper: "Stepper") -> "Variables": ...

class Burgers1D(Solver):
    params: Dict[str, Any] = {"mu": float}  # parameter schema: 'mu' is a float
```

----

##### Components

```python=
from typing import Callable

class Grid: ...

class Variable:
    grid: Grid

class Stepper:
    def advance(self, delta: float, rhs: Callable, state: "Variables") -> "Variables": ...

class BCond:
    def set(self, state: "Variables") -> "Variables": ...
```

----

###### Example

```python=
import jax

# reference run on the high-resolution grid
hr_solver = Burgers1D(hr_grid, bcond)
hr_data = hr_solver.forward(stepper, dt, t_final)

# low-resolution solver with a learned correction appended to its RHS
lr_solver = Burgers1D(lr_grid, bcond)
lr_solver.append_rhs(lambda state: state + neural_net(state))

# "online" loss: the time-stepping happens inside the differentiated function
def online_mse(lr_solver):
    return mse(hr_data, lr_solver.forward(stepper, dt, t_final))

# requires lr_solver to be a JAX pytree holding the trainable parameters
grad = jax.grad(online_mse)(lr_solver)
...
```

---

# Reconnaissance

----

## Existing Libraries

* Specific:
    * [jaxdf](https://github.com/ucl-bug/jaxdf) - differentiable numerical discretizations
    * [FEniCS (jax-fenics)](https://github.com/IvanYashchuk/jax-fenics) - differentiable FEniCS PDE solvers usable from JAX
* High-level: [Dynamical System](https://github.com/googleinterns/invobs-data-assimilation/blob/master/dynamical_system.py) - API inspiration on the user side; it wraps JAX-CFD in an easy-to-use interface

----

### JAX-CFD

JAX for Computational Fluid Dynamics

----

* Core finite-volume/finite-difference methods for CFD, written in JAX
* Pseudospectral methods for CFD
* Machine-learning-augmented models for CFD, written in JAX and Haiku
* Data processing with Xarray and Pillow

----

#### Verdict

* ✔️ - low-level API (similar to ours)
* ✔️ - well thought out
* ❌ - missing a LOT of features (finite volume, boundary conditions)
* ❌ - only for Navier-Stokes

---

## Next Steps

- Coordinate future collaboration (when2meet)
- Build from scratch for simple use cases
    - 1D FV with wall boundaries
- Build on our scratch code and JAX-CFD to address all stated requirements
- Potential collaboration with JAX-CFD devs
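
----

#### Sketch: 1D FV with wall boundaries

The Next Steps slide lists a 1D FV model with wall boundaries as a first from-scratch use case. A minimal sketch, not the project's code, assuming viscous Burgers' equation $u_t + (u^2/2)_x = \mu u_{xx}$ (the `mu` of the `Burgers1D` pseudocode), a Rusanov (local Lax-Friedrichs) flux, zero flux through both walls, and forward Euler inside `jax.lax.scan`. The function names `rhs` and `forward` echo the PseudoCode slide but are hypothetical.

```python
import jax
import jax.numpy as jnp


def rhs(u, mu, dx):
    """Hypothetical FV right-hand side for u_t + (u^2/2)_x = mu * u_xx.

    u holds N cell averages; fluxes live on the N+1 cell faces.
    """
    ul, ur = u[:-1], u[1:]  # states left/right of the N-1 interior faces
    a = jnp.maximum(jnp.abs(ul), jnp.abs(ur))  # local max wave speed |f'(u)| = |u|
    # Rusanov flux for f(u) = u^2/2, plus a centred viscous flux -mu * du/dx
    f = 0.25 * (ul**2 + ur**2) - 0.5 * a * (ur - ul) - mu * (ur - ul) / dx
    # wall boundaries: zero flux through the two boundary faces
    flux = jnp.concatenate([jnp.zeros(1), f, jnp.zeros(1)])
    return -(flux[1:] - flux[:-1]) / dx


def forward(u0, mu, dx, dt, n_steps):
    """Hypothetical forward-Euler rollout inside lax.scan, differentiable end to end."""
    def step(u, _):
        return u + dt * rhs(u, mu, dx), None

    u_final, _ = jax.lax.scan(step, u0, None, length=n_steps)
    return u_final


# usage: a sine profile on [0, 1] steepens into a front at x = 0.5
n = 64
dx = 1.0 / n
x = (jnp.arange(n) + 0.5) * dx  # cell centres
u0 = jnp.sin(2 * jnp.pi * x)
u_final = forward(u0, 0.01, dx, 1e-3, 200)


def energy(mu):
    # kinetic energy of the final state, as a function of the viscosity
    return 0.5 * jnp.sum(forward(u0, mu, dx, 1e-3, 200) ** 2) * dx


# sensitivity of the final energy to mu, back-propagated through all 200 steps
dE_dmu = jax.grad(energy)(0.01)
```

Imposing the walls on the face fluxes, rather than through ghost values, makes the interior fluxes cancel pairwise, so the total mass $\sum_i u_i \Delta x$ is conserved by construction (up to round-off).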
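
----

#### Sketch: differentiating through a solver object

The Example slide calls `jax.grad(online_mse)(lr_solver)`, i.e. it differentiates with respect to the solver object itself. JAX only supports this if the solver is registered as a pytree. A minimal sketch of that pattern, assuming a hypothetical `Diffusion1D` solver for $u_t = \mu u_{xx}$ on a periodic grid (a stand-in, not the project's `Burgers1D`): the learnable `mu` is a pytree leaf, while `dx` and `dt` are static auxiliary data excluded from differentiation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Diffusion1D:
    """Hypothetical solver for u_t = mu * u_xx on a periodic grid (illustration only)."""

    def __init__(self, mu, dx, dt):
        self.mu = mu  # learnable parameter: pytree leaf, differentiated
        self.dx = dx  # grid spacing: static aux data, not differentiated
        self.dt = dt  # time step: static aux data, not differentiated

    def tree_flatten(self):
        return (self.mu,), (self.dx, self.dt)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

    def rhs(self, u):
        # second-order centred Laplacian; jnp.roll wraps around (periodic BC)
        return self.mu * (jnp.roll(u, -1) - 2 * u + jnp.roll(u, 1)) / self.dx**2

    def forward(self, u0, n_steps):
        # forward-Euler rollout; returns the trajectory of all n_steps states
        def step(u, _):
            u_next = u + self.dt * self.rhs(u)
            return u_next, u_next

        _, traj = jax.lax.scan(step, u0, None, length=n_steps)
        return traj


n = 32
dx = 1.0 / n
u0 = jnp.sin(2 * jnp.pi * jnp.arange(n) * dx)

ref = Diffusion1D(0.01, dx, 1e-3)  # reference run playing the role of hr_data
data = ref.forward(u0, 100)


def online_mse(solver):
    # the rollout sits inside the loss, so gradients flow through every step
    return jnp.mean((solver.forward(u0, 100) - data) ** 2)


guess = Diffusion1D(0.02, dx, 1e-3)  # wrong viscosity, to be corrected
grads = jax.grad(online_mse)(guess)  # same type as `guess`; grads.mu holds dL/dmu
```

`grads` comes back as a `Diffusion1D` whose `mu` field holds $\partial L / \partial \mu$; since the guess overestimates the reference viscosity, that derivative is positive here, and a gradient step would lower `mu`.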