<style>
.reveal { font-size: 27px; }
.green { color: green; }
.red { color: red; }
</style>
## Facilitating machine learning in Fortran using FTorch
<u>Joe Wallwork</u><sup>1</sup>, Jack Atkinson<sup>1</sup>, Niccolò Zanotti<sup>1,2</sup>, Dominic Orchard<sup>1,3</sup> et al.
<sup>1</sup>Institute of Computing for Climate Science, University of Cambridge, U.K.
<sup>2</sup>University of Bologna, Italy. <sup>3</sup>University of Kent, U.K.
<!---->
<img src="https://hackmd.io/_uploads/SyNht0cpyg.png" alt="drawing" width="400"/>
<img src="https://hackmd.io/_uploads/ryH0C69a1l.png" alt="drawing" width="300"/>
Slides: https://hackmd.io/@jwallwork/darc-seminar-2025?type=slide
<!-- 60 minute slot. 30-35 for presentation and then 25-30 for discussion. -->
<!-- Abstract
Fortran continues to be important to several scientific communities - including weather and climate forecasting - thanks to its natural support for mathematical operations on arrays, continued updates to the language standards and compilers, and considerable societal inertia. However, it lacks native support for machine learning (ML), for which popular tools such as PyTorch are written in Python. Making use of the iso_c_binding intrinsic, FTorch (https://github.com/Cambridge-ICCS/FTorch) facilitates the easy deployment of PyTorch-based models within Fortran codes by interfacing directly to the C++ Torch backend, avoiding the need for a Python runtime. In this talk, we will detail the FTorch approach to address the language inter-operation problem, showcase research that FTorch has enabled, and discuss our current and future development plans, including automatic differentiation and online training.
-->
---
## Funding
* The [Institute of Computing for Climate Science (ICCS)](https://iccs.cam.ac.uk) acknowledges funding from [Schmidt Sciences](https://www.schmidtsciences.org).
* This project also received funding from a [C2D3-Accelerate grant](https://science.ai.cam.ac.uk/news/2024-12-09-exploring-novel-applications-of-ai-for-research-and-innovation-%E2%80%93-announcing-our-2024-funded-projects.html) for novel applications of AI in research and innovation.
---
# Overview
---
## Motivation - hybrid models
* High-quality climate simulations are expensive.
* Based on Earth System Models (ESMs), which have many sub-components.
* Machine learning (ML) can be used to emulate/accelerate model components.
<!---->

<!--
Mention spectrum from physical models through to end-to-end ML models.
1. Classical models
2. Mostly classical model + ML used to improve upon hand-tuned parametrisations.
3. Hybrid model with whole sub-models replaced by emulators.
4. End-to-end ML model.
-->
---
## Challenges
* Reproducibility (ensure the ML component behaves the same in situ as it did offline).
* Reusability
  * Making the ML approach available to many models.
  * Facilitating easy retraining and adaptation.
* Language interoperation
---
## But it's 2025... why Fortran?
Most weather and climate models are written in Fortran, e.g., IFS, UM, LFRic, ICON, WRF, CESM.

<sup><sup><br/>[Mathematical Bridge](https://en.wikipedia.org/wiki/Mathematical_Bridge) by [cmglee](https://commons.wikimedia.org/wiki/User:Cmglee) used under [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/deed.en)</sup></sup>
Much machine learning is conducted in Python, e.g., PyTorch, TensorFlow.
---
## But it's 2025... why Fortran?
Most weather and climate models are written in Fortran, e.g., IFS, UM, LFRic, ICON, WRF, CESM.

<sup><sup><br/>[Mathematical Bridge](https://en.wikipedia.org/wiki/Mathematical_Bridge) by [cmglee](https://commons.wikimedia.org/wiki/User:Cmglee) used under [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/deed.en)</sup></sup>
Much machine learning is conducted in Python, e.g., PyTorch, TensorFlow.
* Fortran arrays natively support mathematical operations $\implies$ a natural choice for scientific computing (see the example below).
* Low-level language with reasonable support from GPU vendors.
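
A small illustration of the first point: whole-array arithmetic and intrinsics apply element-wise, with no explicit loops (values are arbitrary):
```fortran
program array_ops
  implicit none
  real, dimension(3) :: u = [1.0, 2.0, 3.0], v = [4.0, 5.0, 6.0], w

  w = 2.0 * u + sin(v)   ! element-wise arithmetic and intrinsics
  print *, w
  print *, sum(u * v)    ! dot product from array intrinsics
end program array_ops
```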
---
## Possible approaches
* Implement the ML code in Fortran, e.g., [neural-fortran](https://github.com/modern-fortran/neural-fortran), [fiats](https://github.com/BerkeleyLab/fiats)
  * Additional work, reproducibility issues, hard for complex architectures.
* [Forpy](https://github.com/ylikx/forpy)
  * Easy to add, harder to use with ML, requires a Python runtime.
* [SmartSim](https://github.com/CrayLabs/SmartSim)
  * Python 'control centre' around Redis: generic/versatile, learning curve, data copying.
* [Infero](https://github.com/ecmwf/infero)
  * Thin layer on top of interchangeable inference engines, multiple languages, data copying.
* [TorchFort](https://github.com/NVIDIA/TorchFort)
  * No data copying, NVIDIA-specific GPU support.
* [Fortran-Keras Bridge](https://github.com/scientific-computing/FKB)
  * Keras-only, abandonware.
---
## FTorch - overview
Fortran interface for PyTorch, https://cambridge-iccs.github.io/FTorch.

* Open source (MIT license) and open development.
* Designed to be familiar to both Fortran programmers and PyTorch users.
* Extensive unit tests and plenty of examples.
* Fortran community tooling (FORD, pFUnit, Fortitude).
---
## FTorch - technical details
* Uses `iso_c_binding` to interface with the Torch C++ backend. <!-- intrinsic to Fortran since the 2003 standard. -->
* Heavy use of pointers $\implies$ no data copying.
<!-- Object-oriented Fortran -->
```fortran
!> Type for holding a Torch tensor.
type torch_tensor
  type(c_ptr) :: p = c_null_ptr  !! pointer to the tensor in memory
contains
  ! ...
end type torch_tensor
```
* Couple directly to `libtorch` $\implies$ no need for Python runtime.
* Support for CUDA, XPU, and MPS GPU devices. <!-- Plan to look at AMD soon. -->
* Easy to build and link with CMake or GNU Make.
* Utility for converting PyTorch models to TorchScript.
---
## Pointer to Fortran array
```fortran
use, intrinsic :: iso_fortran_env, only : sp => real32
use :: ftorch
implicit none
! Fortran variables
real(sp), dimension(1,3,224,224), target :: in_data
real(sp), dimension(1, 1000), target :: out_data
integer, parameter :: n_inputs = 1
! Torch Tensors
type(torch_tensor), dimension(1) :: in_tensors
type(torch_tensor) :: out_tensor
! Populate Fortran data
call random_number(in_data)
! Create input/output tensors from the above arrays
call torch_tensor_from_array(in_tensors(1), in_data, torch_kCPU)
call torch_tensor_from_array(out_tensor, out_data, torch_kCPU)
```
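
The remaining steps load a TorchScript model and run inference on these tensors. A minimal sketch continuing the snippet above (declarations shown first for brevity); the filename is hypothetical and exact argument lists follow the current FTorch API, which may differ between versions:
```fortran
type(torch_model) :: model
type(torch_tensor), dimension(1) :: out_tensors

! The forward call takes arrays of tensors; wrap out_data in a one-element array
call torch_tensor_from_array(out_tensors(1), out_data, torch_kCPU)

! Load the TorchScript model produced by the pt2ts utility (filename assumed)
call torch_model_load(model, "saved_model.pt", torch_kCPU)

! Run inference; results are written straight into out_data (no copies)
call torch_model_forward(model, in_tensors, out_tensors)
```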
---
## FTorch (offline) training workflow
<!---->
<!---->

<!--
0. Design ML model in PyTorch.
Then four distinct steps:
1. Run Fortran simulations, saving training data to file.
2. Train the ML model in PyTorch, writing model to file.
3. Convert model file to TorchScript using `pt2ts` utility.
4. Read TorchScript model from Fortran to run inference.
-->
---
## Case studies
* [MiMA](https://github.com/DataWaveProject/MiMA-machine-learning) (DataWave): emulation of an existing parametrisation; replacing the Forpy approach with FTorch brought performance gains.
  * "Identical" neural networks have very different behaviours when deployed for inference ([Mansfield and Sheshadri, 2024](https://doi.org/10.1029/2024MS004292)).
* [CAM-GW](https://github.com/DataWaveProject/CAM) (DataWave): neural network parametrisations of gravity waves in CAM.
* [ICON](https://www.icon-model.org/): FTorch is now actively used by ICON for ML coupling.
  * ML convection parametrisations work best when non-causal relations are eliminated ([Heuer et al., 2024](https://doi.org/10.1029/2024MS004398)).
* [GloSea6](https://www.metoffice.gov.uk/research/climate/seasonal-to-decadal/gpc-outlooks/user-guide/global-seasonal-forecasting-system-glosea6): replaced a BiCGStab solver with ML to speed up forecasting ([Park and Chung, 2025](https://doi.org/10.3390/atmos16010060)).
---
# Online training and differentiable programming
---
## Proposed FTorch online training workflow
<!---->
<!---->
<!---->

<!--
Now there's only two distinct steps, one of which is very minor.
-->
---
## Online training - pros and cons
* <span class="red">-Difficult to implement in most frameworks.</span>
* <span class="green">+Avoids saving large volumes of training data. </span>
* <span class="green">+Avoids need to convert between Python and Fortran data formats.</span>
* <span class="green">+Possibility to expand loss function scope to include downstream model code.</span>
---
## Online training in FTorch - `Optimizer` classes
* Expose `torch::optim::Adam`, `torch::optim::SGD`, etc.
* Expose their `zero_grad` and `step` methods.
```fortran
!> Type for holding a torch optimizer.
type torch_optim
  type(c_ptr) :: p = c_null_ptr  !! pointer to the optimizer in memory
contains
  procedure :: step => torch_optim_step
  procedure :: zero_grad => torch_optim_zero_grad
  final :: torch_optim_delete
end type torch_optim
```
---
## Online training in FTorch - loss functions
* We haven't exposed any built-in loss functions yet.
* However, `torch_tensor_sum` and `torch_tensor_mean` are implemented, so simple losses can be assembled from the existing primitives (see the sketch below).
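
For example, a mean-squared-error loss can already be assembled from these primitives together with the overloaded operators introduced later in the talk. A sketch, assuming `torch_tensor_mean` can be used as a function returning a tensor (the exact interface may differ):
```fortran
use ftorch
implicit none
type(torch_tensor) :: pred, obs, loss
real, dimension(4), target :: pred_arr = [1.1, 2.2, 2.9, 4.2]  ! illustrative values
real, dimension(4), target :: obs_arr = [1.0, 2.0, 3.0, 4.0]
real, dimension(1), target :: loss_arr

call torch_tensor_from_array(pred, pred_arr, torch_kCPU, requires_grad=.true.)
call torch_tensor_from_array(obs, obs_arr, torch_kCPU)
call torch_tensor_from_array(loss, loss_arr, torch_kCPU)

! Mean-squared error from existing primitives (interfaces assumed, not final)
loss = torch_tensor_mean((pred - obs) * (pred - obs))
```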
---
## Online training - expanded loss function
Consider a loss function involving downstream model code, e.g.,
$$J(\theta)=\int_\Omega(u-u_{ML}(\theta))^2\;\mathrm{d}x,$$
where
* $u$: solution from the physical model
* $u_{ML}(\theta)$: solution from a hybrid model with some ML parameters $\theta$.

Computing $\mathrm{d}J/\mathrm{d}\theta$ requires differentiating Fortran code as well as ML code.
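Written out (a one-line derivation from the definition above):
$$\frac{\mathrm{d}J}{\mathrm{d}\theta}=-2\int_\Omega(u-u_{ML}(\theta))\,\frac{\partial u_{ML}}{\partial\theta}\;\mathrm{d}x,$$
so the sensitivity $\partial u_{ML}/\partial\theta$ threads through the hybrid model, i.e., through both the Fortran and the ML code.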
---
## Implementing AD in FTorch
* Expose `autograd` functionality from Torch.
* e.g., `requires_grad` argument and `backward` methods.
* Overload mathematical operators (`=`,`+`,`-`,`*`,`/`,`**`).
```fortran
interface operator (*)
  module procedure torch_tensor_multiply
end interface

!> Overloads multiplication operator for two tensors.
function torch_tensor_multiply(tensor1, tensor2) result(output)
  use, intrinsic :: iso_c_binding, only : c_associated
  type(torch_tensor), intent(in) :: tensor1  !! First tensor to be multiplied
  type(torch_tensor), intent(in) :: tensor2  !! Second tensor to be multiplied
  type(torch_tensor) :: output               !! Tensor to hold the product
  ! [C interface definition]
  call torch_tensor_multiply_c(output%p, tensor1%p, tensor2%p)
end function torch_tensor_multiply
```
---
## Using AD in PyTorch
```python
"""Based on https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html."""
import torch
# Construct input tensors with requires_grad=True
a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([6.0, 4.0], requires_grad=True)
# Compute some mathematical expression
Q = 3 * (a**3 - b * b / 3)
# Reverse mode
Q.backward(gradient=torch.ones_like(Q))
print(a.grad)
print(b.grad)
```
<!-- Note that external gradient is required unless scalar output. -->
---
## Using AD in FTorch
```fortran
use ftorch
type(torch_tensor) :: a, b, Q, multiplier, divisor, dQda, dQdb
real, dimension(2), target :: Q_arr, dQda_arr, dQdb_arr
! Construct input tensors with requires_grad=.true.
call torch_tensor_from_array(a, [2.0, 3.0], torch_kCPU, requires_grad=.true.)
call torch_tensor_from_array(b, [6.0, 4.0], torch_kCPU, requires_grad=.true.)
! Workaround for scalar multiplication and division using 0D tensors
call torch_tensor_from_array(multiplier, [3.0], torch_kCPU)
call torch_tensor_from_array(divisor, [3.0], torch_kCPU)
! Compute some mathematical expression
call torch_tensor_from_array(Q, Q_arr, torch_kCPU)
Q = multiplier * (a**3 - b * b / divisor)
! Reverse mode
call torch_tensor_backward(Q)
call torch_tensor_from_array(dQda, dQda_arr, torch_kCPU)
call torch_tensor_from_array(dQdb, dQdb_arr, torch_kCPU)
call torch_tensor_get_gradient(a, dQda)
call torch_tensor_get_gradient(b, dQdb)
print *, dQda_arr
print *, dQdb_arr
```
---
## Putting it together - optimisation in FTorch
$$\begin{bmatrix}f_1\\f_2\\f_3\\f_4\end{bmatrix}=\mathbf{f}(\mathbf{x};\mathbf{a})=\mathbf{a}\bullet\mathbf{x}\equiv\begin{bmatrix}a_1x_1\\a_2x_2\\a_3x_3\\a_4x_4\end{bmatrix}$$
For a given $\mathbf{x}$ and $\mathbf{b}$ and a loss function $\ell(\mathbf{a})=\overline{(\mathbf{f}(\mathbf{x};\mathbf{a})-\mathbf{b})^2}$, solve the optimisation problem
$$\min_{\mathbf{a}}\ell(\mathbf{a}).$$
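
A hedged sketch of how the pieces could combine to solve this problem in FTorch: the overloaded operators, `torch_tensor_backward`, and the `torch_optim` type follow the earlier slides, while the constructor `torch_optim_sgd`, its arguments, and the use of `torch_tensor_mean` as a function are assumptions rather than the finalised API:
```fortran
use ftorch
implicit none
type(torch_tensor) :: a, x, b, loss
type(torch_optim) :: opt
real, dimension(4), target :: a_arr = [1.0, 1.0, 1.0, 1.0]  ! initial guess for a
real, dimension(4), target :: x_arr = [1.0, 1.0, 1.0, 1.0]  ! x held fixed (illustrative)
real, dimension(4), target :: b_arr = [1.0, 2.0, 3.0, 4.0]  ! target b
real, dimension(1), target :: loss_arr
integer :: iteration

call torch_tensor_from_array(a, a_arr, torch_kCPU, requires_grad=.true.)
call torch_tensor_from_array(x, x_arr, torch_kCPU)
call torch_tensor_from_array(b, b_arr, torch_kCPU)
call torch_tensor_from_array(loss, loss_arr, torch_kCPU)

! Hypothetical constructor: register the parameter tensor(s) and a learning rate
call torch_optim_sgd(opt, [a], learning_rate=1.0e-2)

do iteration = 1, 200
  call opt%zero_grad()
  ! l(a) = mean((a*x - b)**2), built from the overloaded operators
  loss = torch_tensor_mean((a * x - b) * (a * x - b))
  call torch_tensor_backward(loss)
  call opt%step()
end do
```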
---
## Putting it together - optimisation in FTorch
For $\mathbf{a}=\begin{bmatrix}1,1,1,1\end{bmatrix}^T$, $\mathbf{b}:=\begin{bmatrix}1,2,3,4\end{bmatrix}^T$, and $\mathbf{x}_0:=\mathbf{b}$.

In both cases we achieve $\mathbf{f}(\mathbf{x};\mathbf{a})=\begin{bmatrix}1,2,3,4\end{bmatrix}^T$.
---
## Differentiable programming
Lots of new possibilities beyond online training:
* Sensitivity analysis.
* Data assimilation.
* Uncertainty quantification.
* ODE/PDE-constrained optimisation.
* Goal-oriented error estimation.
---
## Case study: predicting timesteps for [UKCA](https://ukca.ac.uk)
* Implicit timestepping, quasi-Newton, full LU decomposition.
* For each time subinterval to be integrated:
  * Start with $\Delta t=3600$.
  * Try to integrate with the current timestep size.
  * If *any grid-box* fails, halve the step and try again (see the sketch below).
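
The retry logic above as a small illustrative sketch (not UKCA code); `attempt_step` is a placeholder for the implicit quasi-Newton solve over all grid-boxes:
```fortran
program adaptive_dt_sketch
  implicit none
  real :: dt
  logical :: success

  ! One subinterval: start from the large default step and halve on failure
  dt = 3600.0
  do
    call attempt_step(dt, success)
    if (success) exit   ! every grid-box converged with this step
    dt = 0.5 * dt       ! any grid-box failed: halve the step and retry
  end do
  print *, "accepted timestep:", dt

contains

  subroutine attempt_step(dt, success)
    real, intent(in) :: dt
    logical, intent(out) :: success
    ! Placeholder: a real implementation would run the implicit solve for
    ! every grid-box and report whether all of them converged.
    success = dt <= 900.0
  end subroutine attempt_step

end program adaptive_dt_sketch
```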
<!--
* In many cases, the default large timestep is insufficient.
* This is particularly true for GPU port with large chunk size.
* Redundant calculation, which it would be good to avoid.
* Predict required timestep using a simple ML model.
-->

---
## Case study: [MiMA-ML](https://github.com/DataWaveProject/MiMA-machine-learning)

Reference: [Espinoza et al. (2022), *Machine Learning Gravity Wave Parameterization Generalizes to Capture the QBO and Response to Increased CO$_2$*](https://doi.org/10.1029/2022GL098174).
<!--
* MiMA = Model of an Idealized Moist Atmosphere
* MiMA captures key dynamical features of the stratosphere-troposphere system that depend critically on GWs at a resolution comparable to state-of-the-art stratosphere resolving global climate models
-->
<!--
Caption: Pressure-time profiles of the zonal mean zonal wind, averaged between 5°S and 5°N and smoothed with a 15-day low pass filter, show the behavior of the Quasi-Biennial Oscillation (QBO) in integrations of (a) the control version of Model of an Idealized Moist Atmosphere with the AD99 parameterization, (b) the model coupled with WaveNet, (c) a 4xCO2 integration with the AD99 parameterization, and (d) a 4xCO 2 integration coupled with WaveNet. Vertical dashed lines separate 5 years segments. The westerly (red) and easterly (blue) bands correspond to winds associated with opposite phases of the QBO. The QBO period and amplitudes are calculated using the transition time (TT) method. The dashed-horizontal line in each panel delineates the model level (≈10 hPa) where the TT method is used.
-->
<!-- AD99 is the model without ML (or AD). (Alexander and Dunkerton 1999 GWP.) It aims to capture the effect of non-orographic GW drag. -->
<!-- (a) 14, (b) 13, (c) 17 (d) 16 -->
<!-- Found online behavior diverges from offline for nearly-identical offline gravity waves models. Mansfield and Sheshadri (2024) -->
---
## Case study: snow density emulator for [Icepack](https://github.com/CICE-Consortium/Icepack/)
- MLP trained on [SnowModel](https://doi.org/10.1175/JHM548.1) high-resolution data generated for five Arctic regions.
- Predicts snow density on top of sea ice, given atmospheric forcing.
- Coupling with the Icepack column-physics sea ice model is in progress.

Reference: [Prasad et al. (2024). *Modeling Snow on Sea Ice Using Physics-Guided Machine Learning*](https://doi.org/10.1017/eds.2024.40).
---
## Summary
* FTorch provides a Fortran interface for PyTorch.
* Torch `autograd` functionality exposed using `iso_c_binding`.
* Exposed tools for optimisation.
* Work in progress on setting up online ML training.
---
## Resources
* FTorch webpage: https://cambridge-iccs.github.io/FTorch.
* Atkinson et al. (2025). FTorch: a library for coupling PyTorch models to Fortran. *Journal of Open Source Software*, 10(107), 7602. https://doi.org/10.21105/joss.07602

* [ICCS ML coupling workshop](https://cambridge-iccs.github.io/ml-coupling-workshop) - 3-4 September, Cambridge, U.K.
{"title":"DARC seminar","description":"Facilitating online training in Fortran-based climate models","contributors":"[{\"id\":\"033ac354-bcb8-4c50-8db3-75282f8d798a\",\"add\":20283,\"del\":3322,\"latestUpdatedAt\":null},{\"id\":\"cbe2c4f3-f044-40ee-93e0-264f1574f90f\",\"add\":587,\"del\":67}]"}