# SiraNet Atmo Grand Est cherche à étendre ses prévisions fine échelle à toute la région Grand-Est afin de fournir des informations plus précises sur l’exposition journalière aux polluants de l’air. Ces prochaines prévisions fourniront ainsi au grand public une information à une résolution de quelques dizaines de mètres en tout point du territoire. Actuellement, les concentrations en NO2, PM10, PM2.5 et O3 sont calculées et synthétisées au travers d’un indice de qualité de l’air. Ces simulations couvrent les principales villes du Grand-Est : Strasbourg, Mulhouse, Colmar, Metz, Nancy, Reims et Troyes. Cet indice est déterminé grâce à la chaîne de traitement Prevision’Air pour le journée en cours et les 2 jours à venir. Il peut être consulté sur le site internet de l’AASQA www.atmo-grandest.eu. Le nouvel objectif est donc de fournir une information détaillée sur toute la région et cette volonté se manifeste via le projet PREIPA. ![](https://i.imgur.com/ruqXB3H.png) ## 1. Contexte La chaîne en cours de développement s’appuie sur les capacités d’un réseau de neurones convolutif pour réaliser de l’émulation de modèle. Dans notre cas, la référence sera le modèle Sirane développé par l’Ecole Centrale de Lyon. Pour répondre à cet objectif, le réseau de neurones devra être correctement structuré et entraîné. Pour information, un réseau de neurones fonctionne en plusieurs temps. Tout d’abord, il est entraîné sur des données d’apprentissage, ici des données Sirane, afin de générer des coefficients de pondération qui serviront par la suite de paramètres de base lors de l’exécution. Une fois le réseau de neurones prêt, il peut être exécuté autant de fois que nécessaire. ![](https://i.imgur.com/3IpjziH.png) Comme indiqué sur le schéma ci-contre, la routine actuelle utilise une partie des données nécessaires au modèle Sirane pour s’entraîner ou réaliser une simulation. Pour faciliter leur utilisation au sein du réseau de neurones, ces entrées sont cadastrées sur une grille ayant les mêmes dimensions que la grille de sortie c’est-à-dire sur des mailles de 25m de côté. L’entraînement d’un réseau de neurones peut être gourmand en ressources informatiques c’est pourquoi cette étape est réalisée sur un serveur équipé d’une carte GPU. Pour l’occasion, nous travaillons en partenariat avec le centre de calcul ROMEO de Reims. ## 2. Modèle Il existe un grand nombres d’architectures de réseaux de neurones. Nous travaillons ici avec des réseaux convolutifs, les plus à même de prendre en compte la structure spatiale de données 2D, correspondant aux champs de surface de concentration des polluants. Différentes familles de réseaux ont été implémentés dans le logiciel existant (CNN, auto-encodeur (variationnel ou non), U-Net). La version opérationnelle initiée fin 2021 utilise un CNN avec un nombre de paramètres à estimer de l’ordre du million. Après entraînement sur le jeu de données d’apprentissage, ces paramètres sont utilisés pour l’inférence sur les données d’entrées produites quotidiennement par Atmo Grand Est. ![](https://i.imgur.com/l6MzAkH.png) Les développements menés ont abouti à la construction d’une filière opérationnelle, qui ne nécessitera pas à l’avenir de modifications majeures. La structure du logiciel a toutefois été conçu pour permettre d’améliorer ses performances, en se focalisant uniquement sur les blocs méthodologiques, à savoir l’architecture du réseau de neurones, ce qui ouvre la porte à de futurs travaux à vocation plus scientifique, faisant le lien entre deep learning et qualité de l’air ## 3. Résultats préliminaires Les gains entre une approche réseaux de neurones et la modélisation classique peuvent se faire sur la base de trois indicateurs : **Visuel**: la carte produite par le réseau est-elle réaliste et semblable à celle du modèle SIRANE? Les premiers éléments de réponse nous permettent d’affirmer que oui (voir ci-dessous à gauche l’animation O3 sur le 14/12/2019 de la carte SIRANE (à gauche) et du réseau (à droite). ![](https://i.imgur.com/aubHGVi.gif) **Statistique** : les prédictions du réseaux sont-elles conformes à des niveaux de performances attendus sur des métriques telles que corrélation et écart quadratique moyen (RMSE) ? Là encore les premiers résultats sont encourageants (voir ci-dessous à droite), avec des performances globalement très satisfaisantes. Des améliorations futures seront apportées en focalisant l’entraînement du réseau sur les zones à proximité des principaux axes routiers. ![](https://i.imgur.com/UxIAAAs.png) **Computationnel**: le temps de calcul est-il significativement réduit avec la nouvelle approche ? De manière certaine oui, la filière opérationnelle produit actuellement 24 cartes horaires pour un polluant en moins d’une minute (après lecture des données d’entrées). Des progrès restent à faire sur la lecture des données d’entrées du réseau, lourdes actuellement, pour réduire encore le temps de traitement total de la filière opérationnelle. ## 5. Actions en développement Code couleur pour préciser le statut des tâches suivantes : <span style="color:blue">**Done**</span> <span style="color:green">**In progress**</span> <span style="color:red">**To do**</span> * <span style="color:red"> Améliorer modèle U-Net </span> * <span style="color:red"> Automatisation des métriques d'évaluation à l'entraînement </span> * <span style="color:red"> Structure du nouveau NetCDF d'entrée </span> ``` netcdf Inputs_SIRANE { dimensions: y = 2400 ; x = 3120 ; Pollutant = 5 ; time = 96 ; variables: double lineic(y, x, Pollutant) ; lineic:_FillValue = NaN ; string pollutant(Pollutant) ; double x(x) ; x:_FillValue = NaN ; double y(y) ; y:_FillValue = NaN ; string Pollutant(Pollutant) ; double surface(y, x, Pollutant) ; surface:_FillValue = NaN ; double ponctual(y, x, Pollutant) ; ponctual:_FillValue = NaN ; string time(time) ; double background(time, Pollutant) ; background:_FillValue = NaN ; double U(time) ; U:_FillValue = NaN ; double Dir(time) ; Dir:_FillValue = NaN ; double Temp(time) ; Temp:_FillValue = NaN ; double SolRad(time) ; SolRad:_FillValue = NaN ; double Precip(time) ; Precip:_FillValue = NaN ; double modulation_lineic(time, Pollutant) ; modulation_lineic:_FillValue = NaN ; double modulation_surface(time, Pollutant) ; modulation_surface:_FillValue = NaN ; } ``` * <span style="color:red"> Dataloader associé </span> <style> pre{ overflow-y: auto; max-height: 200px; } </style> <pre> import numpy as np import math import numbers import matplotlib.pyplot as plt import pytorch_lightning as pl import xarray as xr import torch import torch.nn.functional as F from torch.utils.data import Dataset, ConcatDataset, DataLoader from datetime import datetime device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def find_pad(sl, st, N): k = np.floor(N/st) if (N-k*st)!=0 : pad = (k+1)*st + (sl-st) - N else: pad = 0 return int(pad/2), int(pad-int(pad/2)) class XrDataset(Dataset): """ torch Dataset based on an xarray file with on the fly slicing. """ def __init__(self, path, lvar, slice_win, dim_range=None, strides=None, decode=False, pol='NO2', resize_factor=1, res=10, FMT ="%d/%m/%Y %H:%M"): """ :param path: xarray file :param lvar: list of data variables to fetch :param slice_win: window size for each dimension {<dim>: <win_size>...} :param dim_range: Optional dimensions bounds for each dimension {<dim>: slice(<min>, <max>)...} :param strides: strides on each dim while scanning the dataset {<dim>: <dim_stride>...} :param decode: Whether to decode the time dim xarray """ super().__init__() self.list_of_vars = lvar self.res = res print(path) _ds = xr.open_mfdataset(path) # select Pollutant ipol = np.where(_ds.Pollutant.values == pol)[0][0] _ds = _ds.isel(Pollutant=ipol) # add coordinates newtime = [ datetime.strptime(hh, FMT) for hh in _ds.time.values ] _ds = _ds.assign_coords({"time": newtime}) _ds = _ds.assign_coords({"x": _ds.x.values}) _ds = _ds.assign_coords({"y": _ds.y.values}) if decode: _ds.time.attrs["units"] = "seconds since "+_ds.time.values[0] _ds = xr.decode_cf(_ds) # reshape if resize_factor!=1: _ds = _ds.coarsen(x=resize_factor).mean(skipna=True).coarsen(y=resize_factor).mean(skipna=True) # dimensions self.ds = _ds.sel(**(dim_range or {})) self.Nt = self.ds.time.shape[0] self.Nx = self.ds.x.shape[0] self.Ny = self.ds.y.shape[0] # # I) first padding x and y pad_x = find_pad(slice_win['x'], strides['x'], self.Nx) pad_y = find_pad(slice_win['y'], strides['y'], self.Ny) # get additional data for patch center based reconstruction dX = [pad_ *self.res for pad_ in pad_x] dY = [pad_ *self.res for pad_ in pad_y] dim_range_ = { 'x': slice(np.round(self.ds.x.min().item(),2)-dX[0], np.round(self.ds.x.max().item(),2)+dX[1]), 'y': slice(np.round(self.ds.y.min().item(),2)-dY[0], np.round(self.ds.y.max().item(),2)+dY[1]), 'time': dim_range['time'] } self.ds = _ds.sel(**(dim_range_ or {})) self.Nt = self.ds.time.shape[0] self.Nx = self.ds.x.shape[0] self.Ny = self.ds.y.shape[0] # II) second padding x and y pad_x = find_pad(slice_win['x'], strides['x'], self.Nx) pad_y = find_pad(slice_win['y'], strides['y'], self.Ny) # pad the dataset dX = [pad_ *self.res for pad_ in pad_x] dY = [pad_ *self.res for pad_ in pad_y] pad_ = {'x':(pad_x[0],pad_x[1]), 'y':(pad_y[0],pad_y[1])} self.ds = self.ds.pad(pad_, mode='reflect') self.Nx += np.sum(pad_x) self.Ny += np.sum(pad_y) # III) get lon-lat for the final reconstruction dX = ((slice_win['x']-strides['x'])/2)*self.res dY = ((slice_win['y']-strides['y'])/2)*self.res dim_range_ = { 'x': slice(dim_range_['x'].start+dX, dim_range_['x'].stop-dX), 'y': slice(dim_range_['y'].start+dY, dim_range_['y'].stop-dY), } self.x = np.arange(dim_range_['x'].start, dim_range_['x'].stop + self.res, self.res) self.y = np.arange(dim_range_['y'].start, dim_range_['y'].stop + self.res, self.res) if isinstance(self.list_of_vars,list): list_of_vars_ = self.list_of_vars else: list_of_vars_ = [self.list_of_vars] for var in list_of_vars_: print(var) # according to the number of dimension reshape as (time, x, y) # case 1: time only if len(self.ds[var].shape)==1: newval = np.broadcast_to(self.ds[var].values[:, np.newaxis, np.newaxis], (self.Nt, self.Ny, self.Nx)) self.ds.update({var: (('time','y','x'),np.nan_to_num(newval))}) # case 1: x,y only elif len(self.ds[var].shape)==2: newval = np.transpose(np.dstack([self.ds[var].values]*self.Nt),(2,0,1)) if var=="lineic": newval = np.einsum('ijk,i->ijk',newval,self.ds.modulation_lineic) if var=="surface": newval = np.einsum('ijk,i->ijk',newval,self.ds.modulation_surface) self.ds.update({var: (('time','y','x'),np.nan_to_num(newval))}) # case 2: t,x,y else: self.ds.update({var: (('time','y','x'),np.nan_to_num(self.ds[var].values))}) # returns a dataset if list_of_vars is a list, a dataarray else if isinstance(self.list_of_vars,list): self.ds = self.ds[self.list_of_vars] else: self.ds = self.ds[[self.list_of_vars]] self.slice_win = slice_win self.strides = strides or {} self.ds_size = { dim: max( int(np.ceil((self.ds.dims[dim] - slice_win[dim]) / self.strides.get(dim, 1))) + 1, 0) for dim in slice_win } # remove ds_size elements for time dimension if no contiguous data self.td = [ (self.ds.time.values[i+slice_win['time']-1]-self.ds.time.values[i]).astype('timedelta64[h]').astype(int) for i in range(self.ds_size['time']-slice_win['time']+1) if (i+slice_win['time']-1)<=(len(self.ds.time.values)-1) ] self.id_ok = [i for i in range(self.ds_size['time']-slice_win['time']+1) if self.td[i] == (slice_win['time']-1)] self.ds_size['time'] = len(self.id_ok) def __del__(self): self.ds.close() def __len__(self): size = 1 for v in self.ds_size.values(): size *= v return size def __getitem__(self, item): sl = { dim: slice(self.strides.get(dim, 1) * idx, self.strides.get(dim, 1) * idx + self.slice_win[dim]) for dim, idx in zip(self.ds_size.keys(), np.unravel_index(item, tuple(self.ds_size.values()))) } # change slices for time dimension if no contiguous data sl['time'] = slice(self.id_ok[sl['time'].start], self.id_ok[sl['time'].start] + self.slice_win['time']) # returns a dataset if list_of_vars is a list, a dataarray else if isinstance(self.list_of_vars,list): _item = [ self.ds.isel(**sl)[var].data.astype(np.float32) for var in self.list_of_vars ] else: _item = self.ds.isel(**sl)[self.list_of_vars].data.astype(np.float32) return _item class SIRANetDataset(Dataset): """ Dataset for the SIRANet method: an item contains a slice of OI, mask, and GT does the preprocessing for the item """ def __init__( self, slice_win, dim_range=None, strides=None, # INPUT input_path='/project/projet321/data_sirane/Entrees_sirane.nc', lin_var = 'lineic', # (time, x, y, Pollutant) surf_var = 'surface', # (time, x, y, Pollutant) ponct_var = 'ponctual', # (time, x, y, Pollutant) temp_var = 'Temp', # time prec_var = 'Precip', # time winds_var = 'U', # time windd_var = 'Dir', # time ray_var = 'SolRad', # time back_var = 'background', # fond use_bati = False, bati_var = "hmoy", # OUTPUT output_path='/project/projet321/data_sirane/Sorties_sirane*.nc', pol_var='Concentration', # (time, x, y, Pollutant) pol='NO2', resize_factor=1, res=10, # OUTPUT COARSE coarse_output_path=None, ): super().__init__() self.input_path = input_path self.pol_var = pol_var self.output_path = output_path self.ray_var = ray_var self.windd_var = windd_var self.winds_var = winds_var self.prec_var = prec_var self.temp_var = temp_var self.ponct_var = ponct_var self.surf_var = surf_var self.lin_var = lin_var self.back_var = back_var self.bati_var = bati_var self.coarse_output_path = coarse_output_path self.pol = pol self.use_bati = use_bati # output dataset self.output_ds = XrDataset(output_path, pol_var, slice_win=slice_win, dim_range=dim_range, strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res, FMT='%Y%m%d%H') # input dataset self.input_ds = XrDataset(input_path, [surf_var, lin_var, ponct_var, prec_var, temp_var, windd_var, winds_var, ray_var], slice_win=slice_win, dim_range=dim_range, strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res) if coarse_output_path is None: self.back_ds = XrDataset(input_path, back_var, slice_win=slice_win, dim_range=dim_range, strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res) else: self.back_ds = XrDataset(coarse_output_path, "Concentration", slice_win=slice_win, dim_range=dim_range, strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res) if use_bati == True: self.bati_ds = XrDataset(input_path, bati_var, slice_win=slice_win, dim_range=dim_range, strides=strides, decode=False, pol=pol, resize_factor=resize_factor, res=res) self.norm_stats_input = None self.norm_stats_output = None def set_norm_stats_output(self, stats): self.norm_stats_output = stats def set_norm_stats_input(self, stats): self.norm_stats_input = stats def __len__(self): return len(self.output_ds) def coordXY(self): return self.output_ds.x, self.output_ds.y def __getitem__(self, item): mean_input, std_input = self.norm_stats_input mean_output, std_output = self.norm_stats_output _output_item = (self.output_ds[item] - mean_output[self.pol]) / std_output[self.pol] _input_item = np.concatenate( ( (self.input_ds[item][0] - mean_input['surf']) / std_input['surf'], (self.input_ds[item][1] - mean_input['lin']) / std_input['lin'], (self.input_ds[item][2] - mean_input['ponct']) / std_input['ponct'], (self.back_ds[item] - mean_input['back']) / std_input['back'], (self.input_ds[item][3] - mean_input['prec']) / std_input['prec'], (self.input_ds[item][4] - mean_input['temp']) / std_input['temp'], (self.input_ds[item][5] - mean_input['windd']) / std_input['windd'], (self.input_ds[item][6] - mean_input['winds']) / std_input['winds'], (self.input_ds[item][7] - mean_input['ray']) / std_input['ray'] ), axis=0 ) if self.use_bati == True: _input_item = np.concatenate( (_input_item, (self.bati_ds[item] - mean_input['bati']) / std_input['bati']),axis=0) input_item = _input_item output_item = _output_item return input_item, output_item class SIRANetDataModule(pl.LightningDataModule): def __init__( self, slice_win, dim_range=None, strides=None, train_slices=(slice('2019-01-13 01:00', "2019-07-31 00:00"),), val_slices=(slice('2019-08-01 01:00', "2019-08-31 00:00"),), test_slices=(slice('2019-10-31 01:00', "2019-12-31 00:00"),), # INPUT input_path='/project/projet321/data_sirane/Entrees_sirane.nc', lin_var = 'lineic', # (time, x, y, Pollutant) surf_var = 'surface', # (time, x, y, Pollutant) ponct_var = 'ponctual', # (time, x, y, Pollutant) temp_var = 'Temp', # time prec_var = 'Precip', # time winds_var = 'U', # time windd_var = 'Dir', # time ray_var = 'SolRad', # time back_var = 'background', # fond use_bati = False, bati_var = "hmoy", # OUTPUT output_path='/project/projet321/data_sirane/Sorties_sirane*.nc', pol_var='Concentration', # (time, x, y, Pollutant) pol='NO2', resize_factor=1, res=10, # OUTPUT COARSE coarse_output_path=None, dl_kwargs=None, ): super().__init__() self.pol = pol self.use_bati = use_bati self.resize_factor = resize_factor self.res = res self.dim_range = dim_range self.slice_win = slice_win self.strides = strides self.dl_kwargs = { **{'batch_size': 16, 'num_workers': 2, 'pin_memory': True}, **(dl_kwargs or {}) } self.input_path = input_path self.pol_var = pol_var self.output_path = output_path self.ray_var = ray_var self.windd_var = windd_var self.winds_var = winds_var self.prec_var = prec_var self.temp_var = temp_var self.ponct_var = ponct_var self.surf_var = surf_var self.lin_var = lin_var self.back_var = back_var self.bati_var = bati_var self.coarse_output_path = coarse_output_path self.train_slices, self.test_slices, self.val_slices = train_slices, test_slices, val_slices self.train_ds, self.val_ds, self.test_ds = None, None, None self.norm_stats_input = None self.norm_stats_output = None def compute_norm_stats_input(self, ds): mean = { 'surf': float(xr.concat([_ds.input_ds.ds[self.surf_var] for _ds in ds.datasets], dim='time').mean()), 'lin': float(xr.concat([_ds.input_ds.ds[self.lin_var] for _ds in ds.datasets], dim='time').mean()), 'ponct': float(xr.concat([_ds.input_ds.ds[self.ponct_var] for _ds in ds.datasets], dim='time').mean()), 'back': float(xr.concat([_ds.back_ds.ds[self.back_var] for _ds in ds.datasets], dim='time').mean()), 'prec': float(xr.concat([_ds.input_ds.ds[self.prec_var] for _ds in ds.datasets], dim='time').mean()), 'temp': float(xr.concat([_ds.input_ds.ds[self.temp_var] for _ds in ds.datasets], dim='time').mean()), 'windd': float(xr.concat([_ds.input_ds.ds[self.windd_var] for _ds in ds.datasets], dim='time').mean()), 'winds': float(xr.concat([_ds.input_ds.ds[self.winds_var] for _ds in ds.datasets], dim='time').mean()), 'ray': float(xr.concat([_ds.input_ds.ds[self.ray_var] for _ds in ds.datasets], dim='time').mean()) } std = { 'surf': float(xr.concat([_ds.input_ds.ds[self.surf_var] for _ds in ds.datasets], dim='time').std()), 'lin': float(xr.concat([_ds.input_ds.ds[self.lin_var] for _ds in ds.datasets], dim='time').std()), 'ponct': float(xr.concat([_ds.input_ds.ds[self.ponct_var] for _ds in ds.datasets], dim='time').std()), 'back': float(xr.concat([_ds.back_ds.ds[self.back_var] for _ds in ds.datasets], dim='time').std()), 'prec': float(xr.concat([_ds.input_ds.ds[self.prec_var] for _ds in ds.datasets], dim='time').std()), 'temp': float(xr.concat([_ds.input_ds.ds[self.temp_var] for _ds in ds.datasets], dim='time').std()), 'windd': float(xr.concat([_ds.input_ds.ds[self.windd_var] for _ds in ds.datasets], dim='time').std()), 'winds': float(xr.concat([_ds.input_ds.ds[self.winds_var] for _ds in ds.datasets], dim='time').std()), 'ray': float(xr.concat([_ds.input_ds.ds[self.ray_var] for _ds in ds.datasets], dim='time').std()) } if self.use_bati == True: mean['bati'] = float(xr.concat([_ds.bati_ds.ds[self.bati_var] for _ds in ds.datasets], dim='time').mean()) std['bati'] = float(xr.concat([_ds.bati_ds.ds[self.bati_var] for _ds in ds.datasets], dim='time').std()) # correction on std if uniform variable for key in std: if std[key]==0.: std[key]=1. return mean, std def set_norm_stats_input(self, ds, ns): for _ds in ds.datasets: _ds.set_norm_stats_input(ns) def compute_norm_stats_output(self, ds): mean = float(xr.concat([_ds.output_ds.ds[self.pol_var] for _ds in ds.datasets], dim='time').mean()) std = float(xr.concat([_ds.output_ds.ds[self.pol_var] for _ds in ds.datasets], dim='time').std()) return {self.pol: mean}, {self.pol: std} def set_norm_stats_output(self, ds, ns): for _ds in ds.datasets: _ds.set_norm_stats_output(ns) def get_domain_bounds(self, ds): min_lon = round(np.min(np.concatenate([_ds.output_ds.ds['x'].values for _ds in ds.datasets])), 2) max_lon = round(np.max(np.concatenate([_ds.output_ds.ds['x'].values for _ds in ds.datasets])), 2) min_lat = round(np.min(np.concatenate([_ds.output_ds.ds['y'].values for _ds in ds.datasets])), 2) max_lat = round(np.max(np.concatenate([_ds.output_ds.ds['y'].values for _ds in ds.datasets])), 2) return min_lon, max_lon, min_lat, max_lat def coordXY(self): return self.test_ds.datasets[0].coordXY() def get_domain_split(self): return self.test_ds.datasets[0].output_ds.ds_size def setup(self, stage=None): self.train_ds, self.val_ds, self.test_ds = [ ConcatDataset( [SIRANetDataset( dim_range={**self.dim_range, **{'time': sl}}, strides=self.strides, slice_win=self.slice_win, input_path=self.input_path, ray_var = self.ray_var, windd_var = self.windd_var, winds_var = self.winds_var, prec_var = self.prec_var, temp_var = self.temp_var, ponct_var = self.ponct_var, surf_var = self.surf_var, lin_var = self.lin_var, back_var = self.back_var, use_bati = self.use_bati, bati_var= self.bati_var, output_path = self.output_path, pol_var = self.pol_var, pol = self.pol, resize_factor=self.resize_factor, res = self.res, coarse_output_path = self.coarse_output_path, ) for sl in slices] ) for slices in (self.train_slices, self.val_slices, self.test_slices) ] self.norm_stats_input = self.compute_norm_stats_input(self.train_ds) self.set_norm_stats_input(self.train_ds, self.norm_stats_input) self.set_norm_stats_input(self.val_ds, self.norm_stats_input) self.set_norm_stats_input(self.test_ds, self.norm_stats_input) self.norm_stats_output = self.compute_norm_stats_output(self.train_ds) self.set_norm_stats_output(self.train_ds, self.norm_stats_output) self.set_norm_stats_output(self.val_ds, self.norm_stats_output) self.set_norm_stats_output(self.test_ds, self.norm_stats_output) self.bounding_box = self.get_domain_bounds(self.train_ds) self.ds_size = self.get_domain_split() def train_dataloader(self): return DataLoader(self.train_ds, **self.dl_kwargs, shuffle=True) def val_dataloader(self): return DataLoader(self.val_ds, **self.dl_kwargs, shuffle=False) def test_dataloader(self): return DataLoader(self.test_ds, **self.dl_kwargs, shuffle=False) </pre> * <span style="color:red"> filière S2 * améliorer le modèle double échelle : diffusion model for LR + CNN-based diffusion for HR </span> * <span style="color:red"> coder la nouvelle façon d'upsampler le modèle coarse (beaucoup plus rapide) </span> * test filière S2 </span>