<style>
.reveal {
  font-size: 28px;
}
.green {color: green;}
.red {color: red;}
</style>
# Facilitating online training in Fortran-based climate models
<u>Joe Wallwork</u><sup>1</sup>, Jack Atkinson<sup>1</sup>, Dominic Orchard<sup>1</sup>, et al.
<sup>1</sup>Institute of Computing for Climate Science, University of Cambridge, U.K.
<!---->
<img src="https://hackmd.io/_uploads/SyNht0cpyg.png" alt="drawing" width="400"/>
<img src="https://hackmd.io/_uploads/ryH0C69a1l.png" alt="drawing" width="300"/>
Slides: https://hackmd.io/@jwallwork/euroad-2025?type=slide
<!-- 30 minute slot. -->
<!-- Abstract
Machine learning (ML) based techniques are becoming increasingly popular in numerical simulation, bringing potentially significant gains in efficiency. Whilst popular ML tools such as PyTorch are written in Python, the climate modelling community continues to make heavy use of Fortran for their scientific models, which lacks native ML support. This presents an issue for users because of the challenges of language interoperation. One approach to make use of ML models in Fortran is to use a Fortran interface to PyTorch, such as FTorch. FTorch has supported “offline training” for some time, whereby models are designed and trained in Python, saved to file, and then loaded to run inference in Fortran. In this talk, we will be sharing the latest developments to support “online training”, where the training is done while the Fortran model is running. Online training brings the benefits of avoiding unnecessarily archiving of large volumes of training data and being able to define cost functions in terms of errors involving downstream model code. The talk will detail our approach for enabling online training in FTorch by exposing PyTorch’s autograd module, providing a seamless interface which should be familiar to users of both Fortran and PyTorch. The new functionality will be demonstrated in a climate modelling context.
-->
---
## Funding
* The [Institute of Computing for Climate Science (ICCS)](https://iccs.cam.ac.uk) acknowledges funding from [Schmidt Sciences](https://www.schmidtsciences.org).
* This project also received funding from a [C2D3-Accelerate grant](https://science.ai.cam.ac.uk/news/2024-12-09-exploring-novel-applications-of-ai-for-research-and-innovation-%E2%80%93-announcing-our-2024-funded-projects.html) for novel applications of AI in research and innovation.
---
## Motivation - climate models
<!--
* Climate models provide projections of possible future conditions given different emissions scenarios.
* These inform policy on tackling climate change.
* Adaptation: accurate projections allow us to account for the changing climate in future planning.
-->

[*Intergovernmental Panel on Climate Change (IPCC) AR6 Synthesis Report (Summary for Policymakers), 2023.*](https://www.ipcc.ch/report/ar6/syr/downloads/report/IPCC_AR6_SYR_SPM.pdf)
---
## Motivation - hybrid models
* High quality climate simulations are expensive.
* Based on Earth System Models (ESMs), which have many sub-components.
* Machine learning (ML) can be used to emulate/accelerate model components.
<!---->

<!--
Mention spectrum from physical models through to end-to-end ML models.
1. Classical models
2. Mostly classical model + ML used to improve upon hand-tuned parametrisations.
3. Hybrid model with whole sub-models replaced by emulators.
4. End-to-end ML model.
-->
---
## But it's 2025... why Fortran?
* Almost all climate models are written in Fortran.
* Fortran arrays natively support mathematical operations $\implies$ a natural choice for scientific computing.
* Low-level language with reasonable support from GPU vendors.
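A small illustration of the array point: whole-array arithmetic is built into the language, with no explicit loops required.
```fortran
real, dimension(100) :: a, b, c
a = 1.0
b = 2.0
c = a + 2.0 * b**2  ! elementwise operation over the whole arrays
```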
<!---->

---
## FTorch
* Fortran interface for PyTorch, https://cambridge-iccs.github.io/FTorch.
* Uses `iso_c_binding` to interface with the Torch C++ backend. <!-- intrinsic to Fortran since the 2003 standard. -->
* Designed to be familiar to both Fortran programmers and PyTorch users.
* Couples directly to `libtorch` $\implies$ no need for a Python runtime.
* Support for CUDA, XPU, and MPS GPU devices. <!-- Plan to look at AMD soon. -->
```fortran
!> Type for holding a Torch tensor.
type torch_tensor
  type(c_ptr) :: p = c_null_ptr  !! pointer to the tensor in memory
contains
  ! ...
end type torch_tensor
```
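For example, a tensor can be created directly from a Fortran array via `torch_tensor_from_array` (a minimal sketch following the calling convention used in the AD examples later in these slides; exact signatures may differ between FTorch versions):
```fortran
use ftorch
implicit none
real, dimension(3), target :: my_data = [1.0, 2.0, 3.0]
type(torch_tensor) :: my_tensor
! Create a Torch tensor on the CPU from the Fortran array
call torch_tensor_from_array(my_tensor, my_data, torch_kCPU)
```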
---
## FTorch offline training workflow
<!---->

<!--
0. Design ML model in PyTorch.
Then four distinct steps:
1. Run Fortran simulations, saving training data to file.
2. Train the ML model in PyTorch, writing model to file.
3. Convert model file to TorchScript using `pt2ts` utility.
4. Read TorchScript model from Fortran to run inference.
-->
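A minimal sketch of step 4 (loading the TorchScript model and running inference from Fortran), following the calling conventions used elsewhere in these slides; exact names and signatures may differ between FTorch versions:
```fortran
use ftorch
implicit none
type(torch_model) :: model
type(torch_tensor), dimension(1) :: inputs, outputs
real, dimension(10), target :: in_data, out_data

in_data = 0.0  ! placeholder inputs; in practice these come from the host model
! Wrap the Fortran arrays as Torch tensors
call torch_tensor_from_array(inputs(1), in_data, torch_kCPU)
call torch_tensor_from_array(outputs(1), out_data, torch_kCPU)
! Load the TorchScript model exported from Python, then run a forward pass
call torch_model_load(model, "saved_model.pt", torch_kCPU)
call torch_model_forward(model, inputs, outputs)
```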
---
## Proposed FTorch online training workflow
<!---->
<!---->

<!--
Now there's only two distinct steps, one of which is very minor.
-->
---
## Online training - pros and cons
* <span class="red">-Difficult to implement in most frameworks.</span>
* <span class="green">+Avoids saving large volumes of training data. </span>
* <span class="green">+Avoids need to convert between Python and Fortran data formats.</span>
* <span class="green">+Possibility to expand loss function scope to include downstream model code.</span>
---
## Online training - expanded loss function
Suppose we want to use a loss function involving downstream model code, e.g.,
$$J(\theta)=\int_\Omega(u-u_{ML}(\theta))^2\;\mathrm{d}x,$$
where $u$ is the solution from the physical model and $u_{ML}(\theta)$ is the solution from a hybrid model with some ML parameters $\theta$.
Computing $\mathrm{d}J/\mathrm{d}\theta$ requires differentiating Fortran code as well as ML code.
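Schematically, if the ML component produces an output $y(\theta)$ which the downstream Fortran code maps to the hybrid solution $u_{ML}$ (here $y$ is just an illustrative intermediate variable), the chain rule gives
$$\frac{\mathrm{d}J}{\mathrm{d}\theta}=\frac{\partial J}{\partial u_{ML}}\,\frac{\partial u_{ML}}{\partial y}\,\frac{\partial y}{\partial\theta},$$
where the middle factor requires differentiating the Fortran model and the final factor the ML model.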
---
## Implementing AD in FTorch
* Expose `autograd` functionality from Torch.
* e.g., `requires_grad` argument and `backward` methods.
* Overload mathematical operators (`=`,`+`,`-`,`*`,`/`,`**`).
```fortran
interface operator (*)
  module procedure torch_tensor_multiply
end interface

!> Overloads multiplication operator for two tensors.
function torch_tensor_multiply(tensor1, tensor2) result(output)
  use, intrinsic :: iso_c_binding, only : c_associated
  type(torch_tensor), intent(in) :: tensor1  !! First tensor to be multiplied
  type(torch_tensor), intent(in) :: tensor2  !! Second tensor to be multiplied
  type(torch_tensor) :: output               !! Tensor to hold the product
  ! [C interface block for torch_tensor_multiply_c omitted]
  call torch_tensor_multiply_c(output%p, tensor1%p, tensor2%p)
end function torch_tensor_multiply
```
---
## Difficulty 1: overloading scalar-tensor operations
```fortran
interface operator(*)
  module procedure torch_tensor_multiply               ! tensor * tensor
#:for PREC in PRECISIONS
  module procedure torch_tensor_premultiply_${PREC}$   ! scalar * tensor
  module procedure torch_tensor_postmultiply_${PREC}$  ! tensor * scalar
#:endfor
end interface

interface operator(/)
  module procedure torch_tensor_divide                 ! tensor / tensor
#:for PREC in PRECISIONS
  module procedure torch_tensor_postdivide_${PREC}$    ! tensor / scalar
#:endfor
end interface
```
<!-- Mention fypp -->
* Difficult to account for various data types.
* Unclear how to use `torch::Scalar`.
* Workaround: use 0D tensors as scalars.
<!-- FTorch supports 4 integer precisions and 2 real precisions, which proves tricky. -->
---
## Using AD in PyTorch
```python
"""Based on https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html."""
import torch
# Construct input tensors with requires_grad=True
a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([6.0, 4.0], requires_grad=True)
# Compute some mathematical expression
Q = 3 * (a**3 - b * b / 3)
# Reverse mode
Q.backward(gradient=torch.ones_like(Q))
print(a.grad)
print(b.grad)
```
<!-- Note that external gradient is required unless scalar output. -->
---
## Using AD in FTorch
```fortran
use ftorch
type(torch_tensor) :: a, b, Q, multiplier, divisor, dQda, dQdb
real, dimension(2), target :: Q_arr, dQda_arr, dQdb_arr
! Construct input tensors with requires_grad=.true.
call torch_tensor_from_array(a, [2.0, 3.0], torch_kCPU, requires_grad=.true.)
call torch_tensor_from_array(b, [6.0, 4.0], torch_kCPU, requires_grad=.true.)
! Workaround for scalar multiplication and division using 0D tensors
call torch_tensor_from_array(multiplier, [3.0], torch_kCPU)
call torch_tensor_from_array(divisor, [3.0], torch_kCPU)
! Compute some mathematical expression
call torch_tensor_from_array(Q, Q_arr, torch_kCPU)
Q = multiplier * (a**3 - b * b / divisor)
! Reverse mode
call torch_tensor_backward(Q)
call torch_tensor_from_array(dQda, dQda_arr, torch_kCPU)
call torch_tensor_from_array(dQdb, dQdb_arr, torch_kCPU)
call torch_tensor_get_gradient(a, dQda)
call torch_tensor_get_gradient(b, dQdb)
print *, dQda_arr
print *, dQdb_arr
```
---
## (Minor) difficulty 2: Extracting gradients
* Need to call `torch_tensor_get_gradient` after each call to `torch_tensor_backward` or `torch_tensor_zero_grad`.
* This is probably avoidable with better pointer management on the C++ side.
---
## Online training in FTorch - `Optimizer` classes
* Expose `torch::optim::Adam`, `torch::optim::SGD`, etc., as well as their `zero_grad` and `step` methods.
* This already enables some cool AD applications in FTorch. <!-- e.g., ODE/PDE-constrained optimization -->
---
## Online training in FTorch - loss functions
* We haven't exposed any built-in loss functions yet.
* Implemented `torch_tensor_sum` and `torch_tensor_mean`, though.
---
## Putting it together - running an optimiser in FTorch
$$\begin{bmatrix}f_1\\f_2\\f_3\\f_4\end{bmatrix}=\mathbf{f}(\mathbf{x};\mathbf{a})=\mathbf{a}\bullet\mathbf{x}\equiv\begin{bmatrix}a_1x_1\\a_2x_2\\a_3x_3\\a_4x_4\end{bmatrix}$$
Starting from $\mathbf{a}=\mathbf{x}:=\begin{bmatrix}1,1,1,1\end{bmatrix}^T$, optimise the $\mathbf{a}$ vector such that $\mathbf{f}(\mathbf{x};\mathbf{a})=\mathbf{b}:=\begin{bmatrix}1,2,3,4\end{bmatrix}^T$.
Loss function: $\ell(\mathbf{a})=\overline{(\mathbf{f}(\mathbf{x};\mathbf{a})-\mathbf{b})^2}$.
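A sketch of how such an optimisation loop might look in FTorch is shown below. The optimiser bindings (`torch_optim_*` names, constructor and `lr` argument) and the function form of `torch_tensor_mean` are hypothetical placeholders for illustration; the tensor operations follow the earlier AD example (pre-allocating the result tensors, as done there, is omitted for brevity).
```fortran
use ftorch
implicit none
type(torch_tensor) :: a, x, b, f, loss
! NOTE: torch_optim, torch_optim_sgd, torch_optim_zero_grad and torch_optim_step
! are hypothetical names used for illustration only.
type(torch_optim) :: optimiser
integer :: i

call torch_tensor_from_array(x, [1.0, 1.0, 1.0, 1.0], torch_kCPU)
call torch_tensor_from_array(b, [1.0, 2.0, 3.0, 4.0], torch_kCPU)
call torch_tensor_from_array(a, [1.0, 1.0, 1.0, 1.0], torch_kCPU, requires_grad=.true.)
call torch_optim_sgd(optimiser, [a], lr=0.01)  ! hypothetical constructor

do i = 1, 100
  call torch_optim_zero_grad(optimiser)   ! reset gradients (hypothetical name)
  f = a * x                               ! elementwise product via overloaded *
  loss = torch_tensor_mean((f - b) ** 2)  ! mean-squared error
  call torch_tensor_backward(loss)        ! reverse mode
  call torch_optim_step(optimiser)        ! update a (hypothetical name)
end do
```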
---
## Putting it together - running an optimiser in FTorch

In both cases we achieve $\mathbf{f}(\mathbf{x};\mathbf{a})=\begin{bmatrix}1,2,3,4\end{bmatrix}^T$.
---
## Difficulty 3: `Module` classes
* FTorch currently uses `torch::jit::Module` (optimised for inference).
* Online training will need components from `torch::nn::Module`.
* Recall workflow:
<!---->

<!--
Two ways we could go:
1. Expose `torch::nn::Module`; users construct nets in Fortran.
Downside: there's an endless amount to bring over and there would be some duplication.
2. Give the users the responsibility to write the net in C++.
Downside: users need to know C++.
-->
---
## Case studies - UKCA
* Implicit timestepping, quasi-Newton, full LU decomposition.
* For each time subinterval to be integrated:
  * Start with $\Delta t=3600\,\mathrm{s}$.
  * Try to integrate with the current timestep size.
  * If *any grid-box* fails, halve the step and try again, as sketched below.
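Schematically (illustrative pseudocode only, not UKCA source; `attempt_step` is a hypothetical routine standing in for the quasi-Newton solve over the subinterval):
```fortran
real :: dt
logical :: success
dt = 3600.0  ! initial timestep [s]
do
  call attempt_step(dt, success)  ! try to integrate the whole chunk with step dt
  if (success) exit
  dt = 0.5 * dt                   ! any failing grid-box forces a retry at half the step
end do
```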
<!--
* In many cases, the default large timestep is insufficient.
* This is particularly true for GPU port with large chunk size.
* Redundant calculation, which it would be good to avoid.
* Predict required timestep using a simple ML model.
-->
---
## Case studies - MiMA

[*Espinoza et al. (2022), Machine Learning Gravity Wave Parameterization Generalizes to Capture the QBO and Response to Increased CO$_2$*, Geophysical Research Letters.](https://doi.org/10.1029/2022GL098174)
<!--
* MiMA = Model of an Idealized Moist Atmosphere
* MiMA captures key dynamical features of the stratosphere-troposphere system that depend critically on GWs at a resolution comparable to state-of-the-art stratosphere resolving global climate models
-->
<!--
Caption: Pressure-time profiles of the zonal mean zonal wind, averaged between 5°S and 5°N and smoothed with a 15-day low pass filter, show the behavior of the Quasi-Biennial Oscillation (QBO) in integrations of (a) the control version of Model of an Idealized Moist Atmosphere with the AD99 parameterization, (b) the model coupled with WaveNet, (c) a 4xCO2 integration with the AD99 parameterization, and (d) a 4xCO 2 integration coupled with WaveNet. Vertical dashed lines separate 5 years segments. The westerly (red) and easterly (blue) bands correspond to winds associated with opposite phases of the QBO. The QBO period and amplitudes are calculated using the transition time (TT) method. The dashed-horizontal line in each panel delineates the model level (≈10 hPa) where the TT method is used.
-->
<!-- Confusingly for us, AD99 is the model without ML (or AD). (Alexander and Dunkerton 1999 GWP.) It aims to capture the effect of non-orographic GW drag. -->
<!-- (a) 14, (b) 13, (c) 17 (d) 16 -->
---
## Summary
* FTorch provides a Fortran interface for PyTorch.
* Torch `autograd` functionality exposed using `iso_c_binding`.
* Exposed tools for optimization.
* Work in progress on setting up online ML training.
---
## Resources
* FTorch webpage: https://cambridge-iccs.github.io/FTorch.
* Atkinson et al. (2025). FTorch: a library for coupling PyTorch models to Fortran. *Journal of Open Source Software*, 10(107), 7602. https://doi.org/10.21105/joss.07602.
* [ICCS ML coupling workshop](https://cambridge-iccs.github.io/ml-coupling-workshop) - 3-4 September, Cambridge, U.K.
{"title":"EuroAD talk","description":"Slides for EuroAD presentation","contributors":"[{\"id\":\"033ac354-bcb8-4c50-8db3-75282f8d798a\",\"add\":20361,\"del\":6063}]"}