# SiraNet
Atmo Grand Est cherche à étendre ses prévisions fine échelle à toute la région Grand-Est afin de fournir des informations plus précises sur l’exposition journalière aux polluants de l’air. Ces prochaines prévisions fourniront ainsi au grand public une information à une résolution de quelques dizaines de mètres en tout point du territoire.
Actuellement, les concentrations en NO2, PM10, PM2.5 et O3 sont calculées et synthétisées au travers d’un indice de qualité de l’air. Ces simulations couvrent les principales villes du Grand-Est : Strasbourg, Mulhouse, Colmar, Metz, Nancy, Reims et Troyes. Cet indice est déterminé grâce à la chaîne de traitement Prevision’Air pour le journée en cours et les 2 jours à venir. Il peut être consulté sur le site internet de l’AASQA www.atmo-grandest.eu.
Le nouvel objectif est donc de fournir une information détaillée sur toute la région et cette volonté se manifeste via le projet PREIPA.

## 1. Contexte
La chaîne en cours de développement s’appuie sur les capacités d’un réseau de neurones convolutif pour réaliser de l’émulation de modèle. Dans notre cas, la référence sera le modèle Sirane développé par l’Ecole Centrale de Lyon.
Pour répondre à cet objectif, le réseau de neurones devra être correctement structuré et entraîné.
Pour information, un réseau de neurones fonctionne en plusieurs temps. Tout d’abord, il est entraîné sur des données d’apprentissage, ici des données Sirane, afin de générer des coefficients de pondération qui serviront par la suite de paramètres de base lors de l’exécution. Une fois le réseau de neurones prêt, il peut être exécuté autant de fois que nécessaire.

Comme indiqué sur le schéma ci-contre, la routine actuelle utilise une partie des données nécessaires au modèle Sirane pour s’entraîner ou réaliser une simulation. Pour faciliter leur utilisation au sein du réseau de neurones, ces entrées sont cadastrées sur une grille ayant les mêmes dimensions que la grille de sortie c’est-à-dire sur des mailles de 25m de côté.
L’entraînement d’un réseau de neurones peut être gourmand en ressources informatiques c’est pourquoi cette étape est réalisée sur un serveur équipé d’une carte GPU. Pour l’occasion, nous travaillons en partenariat avec le centre de calcul ROMEO de Reims.
## 2. Modèle
Il existe un grand nombres d’architectures de réseaux de neurones. Nous travaillons ici avec des réseaux convolutifs, les plus à même de prendre en compte la structure spatiale de données 2D, correspondant aux champs de surface de concentration des polluants.
Différentes familles de réseaux ont été implémentés dans le logiciel existant (CNN, auto-encodeur (variationnel ou non), U-Net). La version opérationnelle initiée fin 2021 utilise un CNN avec un nombre de paramètres à estimer de l’ordre du million. Après entraînement sur le jeu de données d’apprentissage, ces paramètres sont utilisés pour l’inférence sur les données d’entrées produites quotidiennement par Atmo Grand Est.

Les développements menés ont abouti à la construction d’une filière opérationnelle, qui ne nécessitera pas à l’avenir de modifications majeures.
La structure du logiciel a toutefois été conçu pour permettre d’améliorer ses performances, en se focalisant uniquement sur les blocs méthodologiques, à savoir l’architecture du réseau de neurones, ce qui ouvre la porte à de futurs travaux à vocation plus scientifique, faisant le lien entre deep learning et qualité de l’air
## 3. Résultats préliminaires
Les gains entre une approche réseaux de neurones et la modélisation classique peuvent se faire sur la base de trois indicateurs :
**Visuel**: la carte produite par le réseau est-elle réaliste et semblable à celle du modèle SIRANE? Les premiers éléments de réponse nous permettent d’affirmer que oui (voir ci-dessous à gauche l’animation O3 sur le 14/12/2019 de la carte SIRANE (à gauche) et du réseau (à droite).

**Statistique** : les prédictions du réseaux sont-elles conformes à des niveaux de performances attendus sur des métriques telles que corrélation et écart quadratique moyen (RMSE) ? Là encore les premiers résultats sont encourageants (voir ci-dessous à droite), avec des performances globalement très satisfaisantes. Des améliorations futures seront apportées en focalisant l’entraînement du réseau sur les zones à proximité des principaux axes routiers.

**Computationnel**: le temps de calcul est-il significativement réduit avec la nouvelle approche ? De manière certaine oui, la filière opérationnelle produit actuellement 24 cartes horaires pour un polluant en moins d’une minute (après lecture des données d’entrées). Des progrès restent à faire sur la lecture des données d’entrées du réseau, lourdes actuellement, pour réduire encore le temps de traitement total de la filière opérationnelle.
## 5. Actions en développement
Code couleur pour préciser le statut des tâches suivantes :
<span style="color:blue">**Done**</span>
<span style="color:green">**In progress**</span>
<span style="color:red">**To do**</span>
* <span style="color:red"> Améliorer modèle U-Net </span>
* <span style="color:red"> Automatisation des métriques d'évaluation à l'entraînement </span>
* <span style="color:red"> Structure du nouveau NetCDF d'entrée
</span>
```
netcdf Inputs_SIRANE {
dimensions:
y = 2400 ;
x = 3120 ;
Pollutant = 5 ;
time = 96 ;
variables:
double lineic(y, x, Pollutant) ;
lineic:_FillValue = NaN ;
string pollutant(Pollutant) ;
double x(x) ;
x:_FillValue = NaN ;
double y(y) ;
y:_FillValue = NaN ;
string Pollutant(Pollutant) ;
double surface(y, x, Pollutant) ;
surface:_FillValue = NaN ;
double ponctual(y, x, Pollutant) ;
ponctual:_FillValue = NaN ;
string time(time) ;
double background(time, Pollutant) ;
background:_FillValue = NaN ;
double U(time) ;
U:_FillValue = NaN ;
double Dir(time) ;
Dir:_FillValue = NaN ;
double Temp(time) ;
Temp:_FillValue = NaN ;
double SolRad(time) ;
SolRad:_FillValue = NaN ;
double Precip(time) ;
Precip:_FillValue = NaN ;
double modulation_lineic(time, Pollutant) ;
modulation_lineic:_FillValue = NaN ;
double modulation_surface(time, Pollutant) ;
modulation_surface:_FillValue = NaN ;
}
```
* <span style="color:red"> Dataloader associé
</span>
<style>
pre{
overflow-y: auto;
max-height: 200px;
}
</style>
<pre>
import numpy as np
import math
import numbers
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import xarray as xr
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from datetime import datetime
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def find_pad(sl, st, N):
k = np.floor(N/st)
if (N-k*st)!=0 :
pad = (k+1)*st + (sl-st) - N
else:
pad = 0
return int(pad/2), int(pad-int(pad/2))
class XrDataset(Dataset):
"""
torch Dataset based on an xarray file with on the fly slicing.
"""
def __init__(self, path, lvar, slice_win, dim_range=None, strides=None, decode=False, pol='NO2', resize_factor=1,
res=10, FMT ="%d/%m/%Y %H:%M"):
"""
:param path: xarray file
:param lvar: list of data variables to fetch
:param slice_win: window size for each dimension {<dim>: <win_size>...}
:param dim_range: Optional dimensions bounds for each dimension {<dim>: slice(<min>, <max>)...}
:param strides: strides on each dim while scanning the dataset {<dim>: <dim_stride>...}
:param decode: Whether to decode the time dim xarray
"""
super().__init__()
self.list_of_vars = lvar
self.res = res
print(path)
_ds = xr.open_mfdataset(path)
# select Pollutant
ipol = np.where(_ds.Pollutant.values == pol)[0][0]
_ds = _ds.isel(Pollutant=ipol)
# add coordinates
newtime = [ datetime.strptime(hh, FMT) for hh in _ds.time.values ]
_ds = _ds.assign_coords({"time": newtime})
_ds = _ds.assign_coords({"x": _ds.x.values})
_ds = _ds.assign_coords({"y": _ds.y.values})
if decode:
_ds.time.attrs["units"] = "seconds since "+_ds.time.values[0]
_ds = xr.decode_cf(_ds)
# reshape
if resize_factor!=1:
_ds = _ds.coarsen(x=resize_factor).mean(skipna=True).coarsen(y=resize_factor).mean(skipna=True)
# dimensions
self.ds = _ds.sel(**(dim_range or {}))
self.Nt = self.ds.time.shape[0]
self.Nx = self.ds.x.shape[0]
self.Ny = self.ds.y.shape[0]
#
# I) first padding x and y
pad_x = find_pad(slice_win['x'], strides['x'], self.Nx)
pad_y = find_pad(slice_win['y'], strides['y'], self.Ny)
# get additional data for patch center based reconstruction
dX = [pad_ *self.res for pad_ in pad_x]
dY = [pad_ *self.res for pad_ in pad_y]
dim_range_ = {
'x': slice(np.round(self.ds.x.min().item(),2)-dX[0], np.round(self.ds.x.max().item(),2)+dX[1]),
'y': slice(np.round(self.ds.y.min().item(),2)-dY[0], np.round(self.ds.y.max().item(),2)+dY[1]),
'time': dim_range['time']
}
self.ds = _ds.sel(**(dim_range_ or {}))
self.Nt = self.ds.time.shape[0]
self.Nx = self.ds.x.shape[0]
self.Ny = self.ds.y.shape[0]
# II) second padding x and y
pad_x = find_pad(slice_win['x'], strides['x'], self.Nx)
pad_y = find_pad(slice_win['y'], strides['y'], self.Ny)
# pad the dataset
dX = [pad_ *self.res for pad_ in pad_x]
dY = [pad_ *self.res for pad_ in pad_y]
pad_ = {'x':(pad_x[0],pad_x[1]),
'y':(pad_y[0],pad_y[1])}
self.ds = self.ds.pad(pad_,
mode='reflect')
self.Nx += np.sum(pad_x)
self.Ny += np.sum(pad_y)
# III) get lon-lat for the final reconstruction
dX = ((slice_win['x']-strides['x'])/2)*self.res
dY = ((slice_win['y']-strides['y'])/2)*self.res
dim_range_ = {
'x': slice(dim_range_['x'].start+dX, dim_range_['x'].stop-dX),
'y': slice(dim_range_['y'].start+dY, dim_range_['y'].stop-dY),
}
self.x = np.arange(dim_range_['x'].start, dim_range_['x'].stop + self.res, self.res)
self.y = np.arange(dim_range_['y'].start, dim_range_['y'].stop + self.res, self.res)
if isinstance(self.list_of_vars,list):
list_of_vars_ = self.list_of_vars
else:
list_of_vars_ = [self.list_of_vars]
for var in list_of_vars_:
print(var)
# according to the number of dimension reshape as (time, x, y)
# case 1: time only
if len(self.ds[var].shape)==1:
newval = np.broadcast_to(self.ds[var].values[:, np.newaxis, np.newaxis], (self.Nt, self.Ny, self.Nx))
self.ds.update({var: (('time','y','x'),np.nan_to_num(newval))})
# case 1: x,y only
elif len(self.ds[var].shape)==2:
newval = np.transpose(np.dstack([self.ds[var].values]*self.Nt),(2,0,1))
if var=="lineic":
newval = np.einsum('ijk,i->ijk',newval,self.ds.modulation_lineic)
if var=="surface":
newval = np.einsum('ijk,i->ijk',newval,self.ds.modulation_surface)
self.ds.update({var: (('time','y','x'),np.nan_to_num(newval))})
# case 2: t,x,y
else:
self.ds.update({var: (('time','y','x'),np.nan_to_num(self.ds[var].values))})
# returns a dataset if list_of_vars is a list, a dataarray else
if isinstance(self.list_of_vars,list):
self.ds = self.ds[self.list_of_vars]
else:
self.ds = self.ds[[self.list_of_vars]]
self.slice_win = slice_win
self.strides = strides or {}
self.ds_size = {
dim: max( int(np.ceil((self.ds.dims[dim] - slice_win[dim]) / self.strides.get(dim, 1))) + 1, 0)
for dim in slice_win
}
# remove ds_size elements for time dimension if no contiguous data
self.td = [ (self.ds.time.values[i+slice_win['time']-1]-self.ds.time.values[i]).astype('timedelta64[h]').astype(int) for i in range(self.ds_size['time']-slice_win['time']+1) if (i+slice_win['time']-1)<=(len(self.ds.time.values)-1) ]
self.id_ok = [i for i in range(self.ds_size['time']-slice_win['time']+1) if self.td[i] == (slice_win['time']-1)]
self.ds_size['time'] = len(self.id_ok)
def __del__(self):
self.ds.close()
def __len__(self):
size = 1
for v in self.ds_size.values():
size *= v
return size
def __getitem__(self, item):
sl = {
dim: slice(self.strides.get(dim, 1) * idx,
self.strides.get(dim, 1) * idx + self.slice_win[dim])
for dim, idx in zip(self.ds_size.keys(),
np.unravel_index(item, tuple(self.ds_size.values())))
}
# change slices for time dimension if no contiguous data
sl['time'] = slice(self.id_ok[sl['time'].start], self.id_ok[sl['time'].start] + self.slice_win['time'])
# returns a dataset if list_of_vars is a list, a dataarray else
if isinstance(self.list_of_vars,list):
_item = [ self.ds.isel(**sl)[var].data.astype(np.float32) for var in self.list_of_vars ]
else:
_item = self.ds.isel(**sl)[self.list_of_vars].data.astype(np.float32)
return _item
class SIRANetDataset(Dataset):
"""
Dataset for the SIRANet method:
an item contains a slice of OI, mask, and GT
does the preprocessing for the item
"""
def __init__(
self,
slice_win,
dim_range=None,
strides=None,
# INPUT
input_path='/project/projet321/data_sirane/Entrees_sirane.nc',
lin_var = 'lineic', # (time, x, y, Pollutant)
surf_var = 'surface', # (time, x, y, Pollutant)
ponct_var = 'ponctual', # (time, x, y, Pollutant)
temp_var = 'Temp', # time
prec_var = 'Precip', # time
winds_var = 'U', # time
windd_var = 'Dir', # time
ray_var = 'SolRad', # time
back_var = 'background', # fond
use_bati = False,
bati_var = "hmoy",
# OUTPUT
output_path='/project/projet321/data_sirane/Sorties_sirane*.nc',
pol_var='Concentration', # (time, x, y, Pollutant)
pol='NO2',
resize_factor=1,
res=10,
# OUTPUT COARSE
coarse_output_path=None,
):
super().__init__()
self.input_path = input_path
self.pol_var = pol_var
self.output_path = output_path
self.ray_var = ray_var
self.windd_var = windd_var
self.winds_var = winds_var
self.prec_var = prec_var
self.temp_var = temp_var
self.ponct_var = ponct_var
self.surf_var = surf_var
self.lin_var = lin_var
self.back_var = back_var
self.bati_var = bati_var
self.coarse_output_path = coarse_output_path
self.pol = pol
self.use_bati = use_bati
# output dataset
self.output_ds = XrDataset(output_path, pol_var, slice_win=slice_win, dim_range=dim_range,
strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res, FMT='%Y%m%d%H')
# input dataset
self.input_ds = XrDataset(input_path, [surf_var, lin_var, ponct_var,
prec_var, temp_var, windd_var, winds_var, ray_var],
slice_win=slice_win, dim_range=dim_range,
strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res)
if coarse_output_path is None:
self.back_ds = XrDataset(input_path, back_var, slice_win=slice_win, dim_range=dim_range,
strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res)
else:
self.back_ds = XrDataset(coarse_output_path, "Concentration", slice_win=slice_win, dim_range=dim_range,
strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res)
if use_bati == True:
self.bati_ds = XrDataset(input_path, bati_var, slice_win=slice_win, dim_range=dim_range,
strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res)
self.norm_stats_input = None
self.norm_stats_output = None
def set_norm_stats_output(self, stats):
self.norm_stats_output = stats
def set_norm_stats_input(self, stats):
self.norm_stats_input = stats
def __len__(self):
return len(self.output_ds)
def coordXY(self):
return self.output_ds.x, self.output_ds.y
def __getitem__(self, item):
mean_input, std_input = self.norm_stats_input
mean_output, std_output = self.norm_stats_output
_output_item = (self.output_ds[item] - mean_output[self.pol]) / std_output[self.pol]
_input_item = np.concatenate( ( (self.input_ds[item][0] - mean_input['surf']) / std_input['surf'],
(self.input_ds[item][1] - mean_input['lin']) / std_input['lin'],
(self.input_ds[item][2] - mean_input['ponct']) / std_input['ponct'],
(self.back_ds[item] - mean_input['back']) / std_input['back'],
(self.input_ds[item][3] - mean_input['prec']) / std_input['prec'],
(self.input_ds[item][4] - mean_input['temp']) / std_input['temp'],
(self.input_ds[item][5] - mean_input['windd']) / std_input['windd'],
(self.input_ds[item][6] - mean_input['winds']) / std_input['winds'],
(self.input_ds[item][7] - mean_input['ray']) / std_input['ray'] ), axis=0 )
if self.use_bati == True:
_input_item = np.concatenate( (_input_item,
(self.bati_ds[item] - mean_input['bati']) / std_input['bati']),axis=0)
input_item = _input_item
output_item = _output_item
return input_item, output_item
class SIRANetDataModule(pl.LightningDataModule):
def __init__(
self,
slice_win,
dim_range=None,
strides=None,
train_slices=(slice('2019-01-13 01:00', "2019-07-31 00:00"),),
val_slices=(slice('2019-08-01 01:00', "2019-08-31 00:00"),),
test_slices=(slice('2019-10-31 01:00', "2019-12-31 00:00"),),
# INPUT
input_path='/project/projet321/data_sirane/Entrees_sirane.nc',
lin_var = 'lineic', # (time, x, y, Pollutant)
surf_var = 'surface', # (time, x, y, Pollutant)
ponct_var = 'ponctual', # (time, x, y, Pollutant)
temp_var = 'Temp', # time
prec_var = 'Precip', # time
winds_var = 'U', # time
windd_var = 'Dir', # time
ray_var = 'SolRad', # time
back_var = 'background', # fond
use_bati = False,
bati_var = "hmoy",
# OUTPUT
output_path='/project/projet321/data_sirane/Sorties_sirane*.nc',
pol_var='Concentration', # (time, x, y, Pollutant)
pol='NO2',
resize_factor=1,
res=10,
# OUTPUT COARSE
coarse_output_path=None,
dl_kwargs=None,
):
super().__init__()
self.pol = pol
self.use_bati = use_bati
self.resize_factor = resize_factor
self.res = res
self.dim_range = dim_range
self.slice_win = slice_win
self.strides = strides
self.dl_kwargs = {
**{'batch_size': 16, 'num_workers': 2, 'pin_memory': True},
**(dl_kwargs or {})
}
self.input_path = input_path
self.pol_var = pol_var
self.output_path = output_path
self.ray_var = ray_var
self.windd_var = windd_var
self.winds_var = winds_var
self.prec_var = prec_var
self.temp_var = temp_var
self.ponct_var = ponct_var
self.surf_var = surf_var
self.lin_var = lin_var
self.back_var = back_var
self.bati_var = bati_var
self.coarse_output_path = coarse_output_path
self.train_slices, self.test_slices, self.val_slices = train_slices, test_slices, val_slices
self.train_ds, self.val_ds, self.test_ds = None, None, None
self.norm_stats_input = None
self.norm_stats_output = None
def compute_norm_stats_input(self, ds):
mean = { 'surf': float(xr.concat([_ds.input_ds.ds[self.surf_var] for _ds in ds.datasets], dim='time').mean()),
'lin': float(xr.concat([_ds.input_ds.ds[self.lin_var] for _ds in ds.datasets], dim='time').mean()),
'ponct': float(xr.concat([_ds.input_ds.ds[self.ponct_var] for _ds in ds.datasets], dim='time').mean()),
'back': float(xr.concat([_ds.back_ds.ds[self.back_var] for _ds in ds.datasets], dim='time').mean()),
'prec': float(xr.concat([_ds.input_ds.ds[self.prec_var] for _ds in ds.datasets], dim='time').mean()),
'temp': float(xr.concat([_ds.input_ds.ds[self.temp_var] for _ds in ds.datasets], dim='time').mean()),
'windd': float(xr.concat([_ds.input_ds.ds[self.windd_var] for _ds in ds.datasets], dim='time').mean()),
'winds': float(xr.concat([_ds.input_ds.ds[self.winds_var] for _ds in ds.datasets], dim='time').mean()),
'ray': float(xr.concat([_ds.input_ds.ds[self.ray_var] for _ds in ds.datasets], dim='time').mean()) }
std = { 'surf': float(xr.concat([_ds.input_ds.ds[self.surf_var] for _ds in ds.datasets], dim='time').std()),
'lin': float(xr.concat([_ds.input_ds.ds[self.lin_var] for _ds in ds.datasets], dim='time').std()),
'ponct': float(xr.concat([_ds.input_ds.ds[self.ponct_var] for _ds in ds.datasets], dim='time').std()),
'back': float(xr.concat([_ds.back_ds.ds[self.back_var] for _ds in ds.datasets], dim='time').std()),
'prec': float(xr.concat([_ds.input_ds.ds[self.prec_var] for _ds in ds.datasets], dim='time').std()),
'temp': float(xr.concat([_ds.input_ds.ds[self.temp_var] for _ds in ds.datasets], dim='time').std()),
'windd': float(xr.concat([_ds.input_ds.ds[self.windd_var] for _ds in ds.datasets], dim='time').std()),
'winds': float(xr.concat([_ds.input_ds.ds[self.winds_var] for _ds in ds.datasets], dim='time').std()),
'ray': float(xr.concat([_ds.input_ds.ds[self.ray_var] for _ds in ds.datasets], dim='time').std()) }
if self.use_bati == True:
mean['bati'] = float(xr.concat([_ds.bati_ds.ds[self.bati_var] for _ds in ds.datasets], dim='time').mean())
std['bati'] = float(xr.concat([_ds.bati_ds.ds[self.bati_var] for _ds in ds.datasets], dim='time').std())
# correction on std if uniform variable
for key in std:
if std[key]==0.:
std[key]=1.
return mean, std
def set_norm_stats_input(self, ds, ns):
for _ds in ds.datasets:
_ds.set_norm_stats_input(ns)
def compute_norm_stats_output(self, ds):
mean = float(xr.concat([_ds.output_ds.ds[self.pol_var] for _ds in ds.datasets], dim='time').mean())
std = float(xr.concat([_ds.output_ds.ds[self.pol_var] for _ds in ds.datasets], dim='time').std())
return {self.pol: mean}, {self.pol: std}
def set_norm_stats_output(self, ds, ns):
for _ds in ds.datasets:
_ds.set_norm_stats_output(ns)
def get_domain_bounds(self, ds):
min_lon = round(np.min(np.concatenate([_ds.output_ds.ds['x'].values for _ds in ds.datasets])), 2)
max_lon = round(np.max(np.concatenate([_ds.output_ds.ds['x'].values for _ds in ds.datasets])), 2)
min_lat = round(np.min(np.concatenate([_ds.output_ds.ds['y'].values for _ds in ds.datasets])), 2)
max_lat = round(np.max(np.concatenate([_ds.output_ds.ds['y'].values for _ds in ds.datasets])), 2)
return min_lon, max_lon, min_lat, max_lat
def coordXY(self):
return self.test_ds.datasets[0].coordXY()
def get_domain_split(self):
return self.test_ds.datasets[0].output_ds.ds_size
def setup(self, stage=None):
self.train_ds, self.val_ds, self.test_ds = [
ConcatDataset(
[SIRANetDataset(
dim_range={**self.dim_range, **{'time': sl}},
strides=self.strides,
slice_win=self.slice_win,
input_path=self.input_path,
ray_var = self.ray_var,
windd_var = self.windd_var,
winds_var = self.winds_var,
prec_var = self.prec_var,
temp_var = self.temp_var,
ponct_var = self.ponct_var,
surf_var = self.surf_var,
lin_var = self.lin_var,
back_var = self.back_var,
use_bati = self.use_bati,
bati_var= self.bati_var,
output_path = self.output_path,
pol_var = self.pol_var,
pol = self.pol,
resize_factor=self.resize_factor,
res = self.res,
coarse_output_path = self.coarse_output_path,
) for sl in slices]
)
for slices in (self.train_slices, self.val_slices, self.test_slices)
]
self.norm_stats_input = self.compute_norm_stats_input(self.train_ds)
self.set_norm_stats_input(self.train_ds, self.norm_stats_input)
self.set_norm_stats_input(self.val_ds, self.norm_stats_input)
self.set_norm_stats_input(self.test_ds, self.norm_stats_input)
self.norm_stats_output = self.compute_norm_stats_output(self.train_ds)
self.set_norm_stats_output(self.train_ds, self.norm_stats_output)
self.set_norm_stats_output(self.val_ds, self.norm_stats_output)
self.set_norm_stats_output(self.test_ds, self.norm_stats_output)
self.bounding_box = self.get_domain_bounds(self.train_ds)
self.ds_size = self.get_domain_split()
def train_dataloader(self):
return DataLoader(self.train_ds, **self.dl_kwargs, shuffle=True)
def val_dataloader(self):
return DataLoader(self.val_ds, **self.dl_kwargs, shuffle=False)
def test_dataloader(self):
return DataLoader(self.test_ds, **self.dl_kwargs, shuffle=False)
</pre>
* <span style="color:red"> filière S2
* améliorer le modèle double échelle : diffusion model for LR + CNN-based diffusion for HR </span>
* <span style="color:red"> coder la nouvelle façon d'upsampler le modèle coarse (beaucoup plus rapide) </span>
* test filière S2
</span>