# Overview of the cGAN code
There are three main parts to getting the cGAN up and running for regional post-processing of global Numerical Weather Prediction (NWP) forecasts. For simplicity these are modularised into sub-directories with individual instructions contained within. When training and running the cGAN, we recommend visiting and following instructions from the sub-directories in the following order:
1) [data](https://github.com/snath-xoc/cGAN_tutorial/tree/main/data): Loading data and creating tfrecords for training.
2) [model](https://github.com/snath-xoc/cGAN_tutorial/tree/main/model): Setting up the model architecture and training the model.
3) [scripts](https://github.com/snath-xoc/cGAN_tutorial/tree/main/scripts): Generating forecasts.
Additionally, sub-directories [evaluation](https://github.com/snath-xoc/cGAN_tutorial/tree/main/evaluation) and [config](https://github.com/snath-xoc/cGAN_tutorial/tree/main/config) contain evaluation scripts and the necessary configuration files for setting data paths and model architecture.
# Data
## Overview of the data loading workflow
The figure below provides an overview of the data loading workflow:

The key files within the directories in order of importance are:
1) [data.py](https://github.com/snath-xoc/cGAN_tutorial/blob/main/data/data.py) where:
a) Individual forecast variables are loaded under `load_fcst`, and all variables are loaded by calling `load_fcst_stack`.
b) Constants (topography and land-sea mask) are loaded under `load_hires_constants`.
c) Truth data is loaded under `load_truth_and_mask`.
d) All three are loaded for a given date and time in `load_fcst_truth_batch`.
2) [data_generator.py](https://github.com/snath-xoc/cGAN_tutorial/blob/main/data/data_generator.py) which has a ```DataGenerator``` class that runs `load_fcst_truth_batch` at different dates and times through iterative `__getitem__` calls. This class is important because it does not load all the data into memory at once; instead, each batch is loaded on demand, so the data is streamed.
3) [tfrecords_generator.py](https://github.com/snath-xoc/cGAN_tutorial/blob/main/data/tfrecords_generator.py) which:
a) Creates tfrecords by calling `write_data`.
b) During training, loads batches from the tfrecords via `create_mixed_dataset`, which is called within its own `DataGenerator` function.
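The streaming idea behind `DataGenerator` can be illustrated with a minimal, hypothetical sketch. The class name `LazyBatchGenerator` and the `loader` argument are illustrative assumptions, not the repository's API; the point is that each `__getitem__` call reads only the dates belonging to one batch.

```python
import math
from typing import Callable, Sequence, Tuple

import numpy as np


class LazyBatchGenerator:
    """Hypothetical sketch: yields one batch per index, loading data only when requested."""

    def __init__(self, dates: Sequence[str], batch_size: int,
                 loader: Callable[[Sequence[str]], Tuple[np.ndarray, np.ndarray]]):
        self.dates = list(dates)
        self.batch_size = batch_size
        self.loader = loader  # e.g. a wrapper around load_fcst_truth_batch

    def __len__(self) -> int:
        # Number of batches; round up so a final partial batch is kept.
        return math.ceil(len(self.dates) / self.batch_size)

    def __getitem__(self, idx: int):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        batch_dates = self.dates[idx * self.batch_size:(idx + 1) * self.batch_size]
        # Only the dates in this batch are read; nothing else is held in memory.
        return self.loader(batch_dates)


def demo_loader(batch_dates):
    # Stand-in for load_fcst_truth_batch: returns (inputs, truth) arrays.
    n = len(batch_dates)
    return np.zeros((n, 8, 8, 3)), np.zeros((n, 8, 8, 1))


gen = LazyBatchGenerator(["2018-01-01", "2018-01-02", "2018-01-03"], batch_size=2, loader=demo_loader)
inputs, truth = gen[1]  # loads only "2018-01-03"
```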
## Setting up data loading in your own machine
Before starting with anything, you need to check that all data paths are correctly set by going into the [config](https://github.com/snath-xoc/cGAN_tutorial/tree/main/config) sub-directory and adjusting the paths by:
1) Making sure the paths for reading in data (`TRUTH_PATH`, `FCST_PATH` and `CONSTANTS_PATH`) and for writing tfrecords (`tfrecords_path`) in [config/data_paths.yaml](https://github.com/snath-xoc/cGAN_tutorial/blob/main/config/data_paths.yaml) are correct.
2) Making sure that the correct `data_path` option is set in [config/local_config.yaml](https://github.com/snath-xoc/cGAN_tutorial/blob/main/config/local_config.yaml).
Creating the training data needed to train the cGAN requires the following steps:
1) Setting up the data generator and visualising data to make sure it is loaded in correctly.
2) Creating the forecast normalisation constants, consisting of the min, max, mean and standard deviation of each variable considered. These are needed because neural networks train more reliably when their inputs share comparable scales, for example rescaled to between 0 and 1 using the min and max, or standardised using the mean and standard deviation.
3) Creating the tfrecord files.
An example notebook is provided under the example_notebooks directory called [create_tfrecords.ipynb](https://github.com/snath-xoc/cGAN_tutorial/blob/main/example_notebooks/create_tfrecords.ipynb) which allows us to follow these steps.
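The normalisation constants from step 2 can be sketched with NumPy as below. The helper names `compute_norm_constants`, `minmax_scale` and `standardise` are hypothetical; which scaling the repository actually applies to each variable is decided in its own code, so both options are shown.

```python
import numpy as np


def compute_norm_constants(fields: dict) -> dict:
    """Hypothetical helper: per-variable min, max, mean and standard deviation (NaNs ignored)."""
    return {
        name: {
            "min": float(np.nanmin(arr)),
            "max": float(np.nanmax(arr)),
            "mean": float(np.nanmean(arr)),
            "std": float(np.nanstd(arr)),
        }
        for name, arr in fields.items()
    }


def minmax_scale(arr, c):
    # Rescale to [0, 1]; assumes c["max"] > c["min"].
    return (arr - c["min"]) / (c["max"] - c["min"])


def standardise(arr, c):
    # Zero mean, unit standard deviation; assumes c["std"] > 0.
    return (arr - c["mean"]) / c["std"]


fields = {"tp": np.array([[0.0, 2.0], [4.0, np.nan]]), "t2m": np.array([270.0, 280.0, 290.0])}
consts = compute_norm_constants(fields)
scaled_tp = minmax_scale(fields["tp"], consts["tp"])
```

In practice the constants would be computed once over the training years only and saved, so that validation and forecast data are scaled with the same values.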
# Model
## Model configuration
The cGAN implementation within this code base has several possible configurations which can be set in [config/config.yaml](https://github.com/snath-xoc/cGAN_tutorial/blob/main/config/config.yaml), with the options listed in subsections below.
### GENERAL
Under the ```GENERAL``` section one can choose the option ```mode``` as:
- ```"det"```: [deterministic](https://github.com/snath-xoc/cGAN_tutorial/blob/main/model/deterministic.py) generator trained without the discriminator
- ```"GAN"```: implements a Wasserstein Conditional Generative Adversarial Network with Gradient Penalty [(WGANGP)](https://github.com/snath-xoc/cGAN_tutorial/blob/main/model/gan.py).
- ```"VAEGAN"```: [Variational Auto-Encoder WGANGP](https://github.com/snath-xoc/cGAN_tutorial/blob/main/model/vaegantrain.py): Same as above but with an additional auto-encoder term before the cGAN.
As well as the option ```problem_type```:
- ```"normal"```: normal problem of post-processing/downscaling.
- ```"autocoarsen"```: coarsen high resolution forecast input before post-processing/downscaling.
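For intuition about `"autocoarsen"`, coarsening can be pictured as averaging non-overlapping blocks of the high-resolution grid. This is a sketch under assumptions: block averaging and the integer `factor` are illustrative choices, and the method and factor used by the code base are not stated here.

```python
import numpy as np


def coarsen_block_mean(field: np.ndarray, factor: int) -> np.ndarray:
    """Coarsen a 2-D field by averaging non-overlapping factor x factor blocks (illustrative)."""
    ny, nx = field.shape
    if ny % factor or nx % factor:
        raise ValueError("grid dimensions must be divisible by factor")
    # Axes become (row_block, row_within_block, col_block, col_within_block).
    return field.reshape(ny // factor, factor, nx // factor, factor).mean(axis=(1, 3))


hires = np.arange(16.0).reshape(4, 4)
coarse = coarsen_block_mean(hires, 2)  # 4x4 grid -> 2x2 grid
```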
### MODEL
Under the ```MODEL``` section, the ```architecture``` can be set to one of:
- ```"normal"```: implements three residual blocks from which outputs are concatenated with constant features of topography and land-sea mask and passed through another three residual blocks.
- ```"forceconv"```: adds an initial pass through a 2-D convolutional layer ([see this article for a good explainer on convolutional arithmetic](https://arxiv.org/pdf/1603.07285)) with a 1x1 kernel before passing through the residual blocks.
- ```"forceconv-long"```: Adds an additional 3 residual blocks to the initial 3 residual blocks under ```"forceconv"```.
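A 1x1 convolution, as used in `"forceconv"`, mixes channels at each pixel but never mixes neighbouring pixels. The NumPy sketch below (not the Keras layer used in the code) shows that it is equivalent to applying the same matrix multiplication to every pixel's channel vector; the channel counts are arbitrary examples.

```python
import numpy as np


def conv1x1(x: np.ndarray, weights: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Apply a 1x1 convolution to x of shape (height, width, in_channels).

    weights has shape (in_channels, out_channels); each pixel's channel
    vector is mapped independently, so no spatial neighbours are mixed.
    """
    return np.einsum("hwc,co->hwo", x, weights) + bias


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 6, 9))  # e.g. 9 forecast input channels on a 5x6 grid
w = rng.normal(size=(9, 16))    # 16 output filters
b = np.zeros(16)
y = conv1x1(x, w, b)            # shape (5, 6, 16)
```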
### GENERATOR
Settings for the generator that can be adjusted to improve cGAN training are:
- ```filters_gen```: generator network width
- ```noise_channels```: number of noise channels fed to the generator
- ```learning_rate_gen```: learning rate (step size) of the generator's optimiser; if training blows up (the losses diverge), decrease this.
Discriminator settings under the ```DISCRIMINATOR``` section also enable adjustment of ```filters_disc``` and ```learning_rate_disc```.
### TRAIN, VAL and EVAL
The most important options that can be set in these sections are:
- ```train_years``` and ```val_years``` used for training and validation respectively
- ```training_weights```: frequency to sample from each bin used in tfrecord creation
- ```num_samples```: total generator training samples
- ```steps_per_checkpoint```: number of batches per checkpoint save
- ```batch_size```: size of batches used during generator training
- ```ensemble_size```: size of ensemble for content loss; use null to turn off
- ```CL_type```: type of content loss (an additional loss on top of the Wasserstein loss; see Harris et al. (2022)), options are:
  - ```'CRPS'```: Continuous Ranked Probability Score
  - ```'CRPS_phys'```: CRPS computed on actual rainfall values, by first applying the inverse log transformation
  - ```'ensmeanMSE'```: Mean Squared Error of the ensemble mean
  - ```'ensmeanMSE_phys'```: ensmeanMSE computed on actual rainfall values, by first applying the inverse log transformation
- ```content_loss_weight```: weighting of the content loss when adding it to the Wasserstein loss
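The two content-loss families can be sketched in NumPy as follows. This is an illustration, not the TensorFlow loss used during training: `crps_ensemble` uses one standard ensemble estimator, mean |x_i - y| minus half the mean |x_i - x_j|, and `to_physical` assumes rainfall was transformed as log10(1 + r), which is an assumption about the `_phys` variants.

```python
import numpy as np


def crps_ensemble(ens: np.ndarray, obs: np.ndarray) -> float:
    """Mean CRPS over grid points; ens has shape (members, ...), obs has shape (...)."""
    skill = np.abs(ens - obs[None]).mean(axis=0)                     # mean |x_i - y|
    spread = np.abs(ens[:, None] - ens[None, :]).mean(axis=(0, 1))   # mean |x_i - x_j|
    return float((skill - 0.5 * spread).mean())


def ensmean_mse(ens: np.ndarray, obs: np.ndarray) -> float:
    """Mean squared error of the ensemble mean."""
    return float(((ens.mean(axis=0) - obs) ** 2).mean())


def to_physical(x):
    # ASSUMPTION: rainfall r was transformed as log10(1 + r); this inverts it.
    return 10.0 ** x - 1.0


rng = np.random.default_rng(1)
ens = rng.random((4, 8, 8))  # 4 ensemble members on an 8x8 grid (log-transformed values)
obs = rng.random((8, 8))
crps_log = crps_ensemble(ens, obs)                             # like 'CRPS'
crps_phys = crps_ensemble(to_physical(ens), to_physical(obs))  # like 'CRPS_phys'
mse_log = ensmean_mse(ens, obs)                                # like 'ensmeanMSE'
```

The CRPS rewards both accuracy and a realistic ensemble spread, whereas ensmeanMSE only scores the ensemble mean.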
## Default cGAN set up
By default we use the ```GAN``` with ```forceconv-long``` which has an architecture depicted below:

Each residual block follows the architecture shown below:

The default network widths are ```filters_gen=128``` and ```filters_disc=512```.
## Step-by-step instructions on training the cGAN
For training and evaluating the cGAN we follow four simple steps:
1) Set up training and evaluation data by calling ```setup_data``` in [setupdata.py](https://github.com/snath-xoc/cGAN_tutorial/blob/main/setupdata.py) which:
a) Calls ```setup_batch_gen``` to load in the tfrecords and sample them according to the specified ```training_weights```.
b) Calls ```setup_full_image_dataset``` which loads in full images for ```val_years``` to evaluate the cGAN over.
2) Set up the model according to the configuration from [config/config.yaml](https://github.com/snath-xoc/cGAN_tutorial/blob/main/config/config.yaml) by calling ```setup_model``` from [setupmodel.py](https://github.com/snath-xoc/cGAN_tutorial/blob/main/setupmodel.py).
3) Train the model by calling ```train_model``` from [model/train.py](https://github.com/snath-xoc/cGAN_tutorial/blob/main/model/train.py).
4) Evaluate the model across the multiple saved checkpoints by calling ```evaluate_multiple_checkpoints``` from [evaluation/evaluation.py](https://github.com/snath-xoc/cGAN_tutorial/blob/main/evaluation/evaluation.py).
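Sampling according to ```training_weights``` (step 1a) can be pictured as drawing a bin for each training sample with probability proportional to that bin's weight. A minimal sketch, assuming ```training_weights``` is a list with one non-negative weight per bin; how `setup_batch_gen` maps tfrecord files to bins is not shown, and `sample_bins` is a hypothetical name.

```python
import numpy as np


def sample_bins(training_weights, n_samples, seed=None):
    """Draw a bin index for each sample with probability proportional to its weight."""
    w = np.asarray(training_weights, dtype=float)
    if w.ndim != 1 or np.any(w < 0) or w.sum() == 0:
        raise ValueError("training_weights must be 1-D, non-negative and have a positive sum")
    rng = np.random.default_rng(seed)
    return rng.choice(len(w), size=n_samples, p=w / w.sum())


# e.g. four rainfall-intensity bins, oversampling the rarer heavy-rain bins
bins = sample_bins([0.4, 0.3, 0.2, 0.1], n_samples=1000, seed=0)
```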
An example notebook called [train_cgan.ipynb](https://github.com/snath-xoc/cGAN_tutorial/blob/main/example_notebooks/train_cgan.ipynb) is provided in the example_notebooks directory; it follows these steps for one training epoch. To fully train the cGAN, it is better to use the command line:
```
python main.py --config path/to/config_file.yaml
```
There are a number of options you can use at this point. These will
evaluate your model after it has finished training:
- `--evaluate` to run checkpoint evaluation (CRPS, rank calculations, RMSE, RALSD, etc.)
- `--plot_ranks` will plot rank histograms (requires `--evaluate`)
If you choose to run `--evaluate`, you must also specify if you want
to do this for all model checkpoints or just a selection. Do this using
- `--eval_full` (all model checkpoints)
- `--eval_short` (recommended; the final third of model checkpoints)
- `--eval_blitz` (the final 4 model checkpoints)
Two things to note:
- These three options work well with the 100 checkpoints that we
have been working with. If this changes, you may want to update
them accordingly.
- Calculating everything, for all model iterations, will take a long
time. Possibly weeks. You have been warned.
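The three evaluation options can be summarised with a hypothetical helper. `select_checkpoints` is illustrative only; the real code may hard-code checkpoint indices (hence the note about 100 checkpoints), and rounding the final third up is an assumption of this sketch.

```python
import math


def select_checkpoints(checkpoints, mode):
    """Pick which saved checkpoints to evaluate, mirroring the flag descriptions above."""
    checkpoints = list(checkpoints)
    if mode == "full":    # --eval_full: every checkpoint
        return checkpoints
    if mode == "short":   # --eval_short: final third, count rounded up
        n = math.ceil(len(checkpoints) / 3)
        return checkpoints[-n:] if n else []
    if mode == "blitz":   # --eval_blitz: final 4 checkpoints
        return checkpoints[-4:]
    raise ValueError(f"unknown mode: {mode!r}")


short = select_checkpoints(range(1, 101), "short")  # checkpoints 67 to 100
```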
As an example, to train a model and evaluate the last few model
checkpoints, you could run:
```
python main.py --config path/to/config_file.yaml --evaluate --eval_blitz --plot_ranks
```
If you've already trained your model, and you just want to run some
evaluation, use the `--no_train` flag, for example:
```
python main.py --config path/to/config_file.yaml --no_train --evaluate --eval_full
```
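The dependencies between these flags can be expressed with `argparse`. This is a hypothetical sketch that only encodes the rules described above (one `--eval_*` option with `--evaluate`, `--plot_ranks` only with `--evaluate`); the actual `main.py` may define and check its arguments differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train and/or evaluate the cGAN (sketch).")
    parser.add_argument("--config", required=True, help="path to the YAML configuration file")
    parser.add_argument("--no_train", action="store_true", help="skip training")
    parser.add_argument("--evaluate", action="store_true", help="run checkpoint evaluation")
    parser.add_argument("--plot_ranks", action="store_true", help="plot rank histograms")
    group = parser.add_mutually_exclusive_group()  # at most one evaluation scope
    group.add_argument("--eval_full", action="store_true", help="all checkpoints")
    group.add_argument("--eval_short", action="store_true", help="final third of checkpoints")
    group.add_argument("--eval_blitz", action="store_true", help="final 4 checkpoints")
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.evaluate and not (args.eval_full or args.eval_short or args.eval_blitz):
        parser.error("--evaluate requires one of --eval_full, --eval_short or --eval_blitz")
    if args.plot_ranks and not args.evaluate:
        parser.error("--plot_ranks requires --evaluate")
    return args


args = parse_args(["--config", "config.yaml", "--evaluate", "--eval_blitz", "--plot_ranks"])
```

`parser.error` prints a usage message and exits, so invalid combinations fail before any training starts.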
# Steps for setting up the cGAN on the cloud:
1. Set up the VM instance called ```tfrecords-cgan-store-nka-t2```: create a folder, save `main.tf` in it (its contents are given further down in this note) and then run
```
terraform init
terraform validate
terraform plan --out=main.tfplan
terraform apply main.tfplan
```
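(Applying a saved plan file does not ask for confirmation; if you instead run `terraform apply` without a plan file, type `yes` when prompted.)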
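2. Attach the SSD disk using Terraform: create a new directory called `ssd_attach` in the same location as the folder containing `main.tf` (so that the relative key path `../key-nka-terraform-access.json` still resolves), save the file `attach_ssd.tf` in it (its contents are given further down in this note), navigate into that directory and again type: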
```
terraform init
terraform validate
terraform plan --out=attach_ssd.tfplan
terraform apply attach_ssd.tfplan
```
3. Log in to the created VM instance using
```
gcloud compute ssh tfrecords-cgan-store-nka-t2 --project=sewaa-416306 --zone=europe-west2-b
```
4. Check whether the disk is available by typing
```
lsblk
```
This should list the disk, whose device path is ```/dev/<disk-name>```.
5. Before mounting the disk, make sure the home directory has the correct ownership:
```
sudo chown -R $USER:$USER /home/<user_directory>
```
6. Mount the disk by typing
```
sudo mount /dev/<disk-name> /home/<user_directory>/
```
7. Make sure the ownership is specified again on the mounted directories
```
sudo chown -R $USER:$USER /home/<user_directory>
```
8. The disk should now be visible and if you go into the home directory and type ```ls```, you should be able to see several folders including ones called ```cGAN```, ```TRUTH```, ```FCST```, ```CONSTANTS``` and ```tfrecords```
9. Before running the script, we need to install the environment. The VM instance should be a ```g2-standard-8``` machine and should already have conda installed, so we only need to create an environment file, `environment.yaml`, with the following contents:
```environment.yaml
name: tf216
channels:
  - conda-forge
  - nodefaults # keep builds reproducible
dependencies:
  # Core runtime
  - python=3.11 # pin the version if you need specific compatibility
  - pip # pip must be listed before the pip section
  # Conda-installed scientific stack
  - numpy==1.26.4
  - h5py
  - xarray
  - netCDF4
  - seaborn
  - zarr
  - xbatcher
  - tqdm
  - scipy
  - xesmf
  # Anything not available (or preferred) via conda goes here
  - pip:
      - tensorflow==2.16.1
      - tf_keras==2.15.0
      - pyyaml
      - properscoring
```
and then run the following commands:
```
# create the environment
conda env create -f environment.yaml
# or: micromamba create -f environment.yaml
# activate it
conda activate tf216 # or: micromamba activate tf216
```
10. Because the code relies on the legacy Keras 2 API provided by `tf_keras`, we also need to tell TensorFlow 2.16 to use it instead of Keras 3:
```
export TF_USE_LEGACY_KERAS=1
```
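If the cGAN is launched from Python (for example in a notebook), the same variable can be set from within the process. A small sketch with a hypothetical helper name; setting the variable before TensorFlow is imported is the safe choice, since setting it afterwards may have no effect.

```python
import os
import sys
import warnings


def enable_legacy_keras() -> None:
    """Point tf.keras at tf_keras (Keras 2) by setting TF_USE_LEGACY_KERAS=1."""
    if "tensorflow" in sys.modules and os.environ.get("TF_USE_LEGACY_KERAS") != "1":
        # TensorFlow may already have chosen its Keras backend at this point.
        warnings.warn("TensorFlow was imported before TF_USE_LEGACY_KERAS was set")
    os.environ["TF_USE_LEGACY_KERAS"] = "1"


# Call this at the very top of the entry-point script, before `import tensorflow`.
enable_legacy_keras()
```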
11. Navigate to ```cGAN/shruti```. All the data paths and configurations should already be set; check the files ```data_path.yaml``` and ```config.yaml``` and feel free to adjust them.
12. Now we can run the cGAN
```
python3 main.py --config config.yaml --evaluate --eval_short
```
Note that this will probably:
A. Throw an error caused by a bug in the installed TensorFlow/`tf_keras` packages, which has to be fixed by editing the offending file directly (see the steps for the `tf_216` setup at the end of this note).
B. Ask for a Weights & Biases (W&B) login. It is best to configure your own W&B account; otherwise, mine is already set up and can be used.
# gcloud commands that are useful
```
gcloud compute ssh <instance_name> --project=<project_name> --zone=<zone>
```
project name = sewaa-416306
zone = europe-west2-b
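To copy files or directories to (or from) an instance:
```
gcloud compute scp --recurse <local_path> <instance_name>:<remote_path> --project=<project_name> --zone=<zone>
```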
# Terraform configuration files
```main.tf
provider "google" {
  credentials = file("../key-nka-terraform-access.json")
  project     = "sewaa-416306"
  region      = "europe-west2"
}

resource "google_compute_instance" "vm_instance" {
  name         = "tfrecords-cgan-store-nka-t2"
  machine_type = "g2-standard-8"
  zone         = "europe-west2-b"

  boot_disk {
    initialize_params {
      image = "projects/deeplearning-platform-release/global/images/tf-ent-2-15-cu121-v20240417-debian-11-py310"
    }
  }

  scheduling {
    on_host_maintenance = "TERMINATE"
  }

  network_interface {
    network = "default"
    access_config {
      // Ephemeral IP
    }
  }

  metadata = {
    enable-oslogin = "TRUE"
  }

  service_account {
    scopes = ["userinfo-email", "compute-ro", "storage-ro"]
  }
}
```
# Attaching SSD disk terraform config
```ssd_attach/attach_ssd.tf
provider "google" {
  credentials = file("../key-nka-terraform-access.json")
  project     = "sewaa-416306"
  region      = "europe-west2"
}

resource "google_compute_attached_disk" "attach_ssd" {
  disk     = "t1-cgan-ssd-disk-large"
  instance = "tfrecords-cgan-store-nka-t2"
  zone     = "europe-west2-b"
}
```
# FAQ
# Good to know
Weights & Biases (W&B): https://wandb.ai/site/
## Requirements from a previous successful run
```
fonttools==4.58.0
Brotli==1.1.0
annotated-types==0.7.0
narwhals==1.41.0
docker-pycreds==0.4.0
tzdata==2025.2
rich==14.0.0
pyshp==2.3.1
cftime==1.6.4
click==8.2.1
sortedcontainers==2.4.0
PySocks==1.7.1
protobuf==4.25.7
zstandard==0.23.0
dask==2025.5.1
platformdirs==4.3.8
tensorboard==2.16.2
tensorboard-data-server==0.7.2
cached-property==1.5.2
setproctitle==1.3.6
tf-keras==2.15.0
tblib==3.1.0
crc32c==2.7.1
ml-dtypes==0.3.2
tqdm==4.67.1
MarkupSafe==3.0.2
psutil==7.0.0
contourpy==1.3.2
Pygments==2.19.1
certifi==2025.4.26
netCDF4==1.7.2
cytoolz==1.0.1
astunparse==1.6.3
hyperframe==6.1.0
six==1.17.0
wheel==0.45.1
google-pasta==0.2.0
numpy==1.26.4
gitdb==4.0.12
partd==1.4.2
python-dateutil==2.9.0.post0
smmap==5.0.2
keras==3.10.0
esmpy==8.8.1
pandas==2.2.3
h5py==3.13.0
requests==2.32.3
numcodecs==0.16.1
h2==4.2.0
unicodedata2==16.0.0
pydantic_core==2.33.2
absl-py==2.3.0
charset-normalizer==3.4.2
namex==0.1.0
termcolor==3.1.0
cf_xarray==0.10.5
xarray==2025.4.0
typing-inspection==0.4.1
matplotlib==3.10.3
zict==3.0.0
markdown-it-py==3.0.0
xesmf==0.8.10
gast==0.6.0
locket==1.0.0
toolz==1.0.0
Werkzeug==3.1.3
donfig==0.8.1.post1
cycler==0.12.1
flatbuffers==25.2.10
pyparsing==3.2.3
pillow==11.2.1
opt_einsum==3.4.0
pycparser==2.22
pyarrow==20.0.0
zarr==3.0.8
importlib_metadata==8.7.0
scipy==1.15.2
msgpack==1.1.0
libclang==18.1.1
Cartopy==0.24.0
PyYAML==6.0.2
kiwisolver==1.4.7
pip==25.1.1
bokeh==3.7.3
munkres==1.1.4
Jinja2==3.1.6
grpcio==1.71.0
mdurl==0.1.2
pydantic==2.11.5
llvmlite==0.44.0
fsspec==2025.5.1
xyzservices==2025.4.0
xbatcher==0.4.0
hpack==4.1.0
optree==0.15.0
tornado==6.5.1
zope.interface==7.2
urllib3==2.4.0
distributed==2025.5.1
patsy==1.0.1
lz4==4.4.4
cffi==1.17.1
statsmodels==0.14.4
DateTime==5.5
tensorflow-io-gcs-filesystem==0.37.1
pytz==2025.2
wandb==0.19.11
sparse==0.17.0
typing_extensions==4.13.2
tensorflow==2.16.1
seaborn==0.13.2
GitPython==3.1.44
idna==3.10
packaging==25.0
Deprecated==1.2.18
properscoring==0.1
cloudpickle==3.1.1
zipp==3.22.0
numba==0.61.2
pyproj==3.7.1
Markdown==3.8
sentry-sdk==2.29.1
shapely==2.1.1
setuptools==80.8.0
wrapt==1.17.2
```
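## Environment errors (2025-09-25)
The following error was encountered; the fix is described in the steps for the tf_216 setup below: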
```
AttributeError: module 'tensorflow._api.v2.compat.v2.__internal__' has no attribute 'register_load_context_function'. Did you mean: 'register_call_context_function'?
```
# Steps for the tf_216 setup
```
micromamba create -n tf_216 python==3.11 numpy==1.26.4
# activate before installing so that the packages land in tf_216
micromamba activate tf_216
pip install tensorflow==2.16.1 tf_keras==2.15
micromamba install xarray netCDF4 xesmf dask zarr=2.18.7
micromamba install matplotlib cartopy shapely regionmask
pip install wandb tqdm seaborn properscoring xbatcher
```
Then run
```
export TF_USE_LEGACY_KERAS=1
```
Then open the following file (the path is relative to your micromamba installation directory)
```
envs/tf_216/lib/python3.11/site-packages/tf_keras/src/saving/legacy/saved_model/load_context.py
```
and change `register_load_context_function` to `register_call_context_function`, so that the line reads:
```
tf.__internal__.register_call_context_function(in_load_context)
```
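Instead of editing the file by hand, the same change can be scripted. A minimal sketch with hypothetical helper names: it locates `tf_keras` with `importlib.util.find_spec`, which for a top-level package finds the install location without importing it (importing `tf_keras` is what triggers the error), and it keeps a `.bak` copy of the original file.

```python
import importlib.util
import pathlib

OLD = "register_load_context_function"
NEW = "register_call_context_function"


def find_load_context_file() -> pathlib.Path:
    """Locate tf_keras's load_context.py without importing tf_keras."""
    spec = importlib.util.find_spec("tf_keras")
    if spec is None or not spec.submodule_search_locations:
        raise FileNotFoundError("tf_keras is not installed in this environment")
    root = pathlib.Path(list(spec.submodule_search_locations)[0])
    return root / "src" / "saving" / "legacy" / "saved_model" / "load_context.py"


def patch_load_context(path) -> bool:
    """Replace every occurrence of OLD with NEW; return True if the file changed."""
    path = pathlib.Path(path)
    text = path.read_text()
    if OLD not in text:
        return False  # already patched (or a different tf_keras version)
    path.with_name(path.name + ".bak").write_text(text)  # keep a backup
    path.write_text(text.replace(OLD, NEW))
    return True
```

Run `patch_load_context(find_load_context_file())` from Python inside the activated `tf_216` environment; it returns `False` if the file has already been patched.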