# SRGAN Pytorch實作(訓練階段)
## 導入所需套件
```python=
import os
import numpy as np
import itertools
import glob
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torchvision.models import vgg19
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch
```
## 創建所需資料夾
```python=
# Output folders: generated sample images and model checkpoints.
for folder in ("images", "saved_models"):
    os.makedirs(folder, exist_ok=True)
```
## ResidualBlock設計
```python=
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers with a PReLU between.

    The input is added back onto the conv output (identity skip connection),
    so channel count and spatial size are preserved.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            # BUG FIX: the original passed 0.8 positionally, which set eps=0.8
            # (a numerical-stability constant that should stay ~1e-5).
            # The intent was momentum=0.8, so pass it by keyword.
            nn.BatchNorm2d(in_features, momentum=0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
        )

    def forward(self, x):
        """Return x + conv_block(x); output shape equals input shape."""
        return x + self.conv_block(x)
```
## Generator架構
```python=
class GeneratorResNet(nn.Module):
    """SRGAN generator: residual trunk plus two PixelShuffle x2 stages (4x upscale).

    Args:
        in_channels: channels of the low-resolution input image.
        out_channels: channels of the super-resolved output image.
        n_residual_blocks: number of ResidualBlock modules in the trunk.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()
        # Head: large 9x9 receptive field, then PReLU.
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())
        # Residual trunk (keeps Sequential indices, so state_dict keys are unchanged).
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(n_residual_blocks)])
        # Post-trunk conv whose output is summed with the head output (long skip).
        # BUG FIX: 0.8 was passed positionally as eps; the intent was momentum=0.8.
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64, momentum=0.8))
        # Two x2 PixelShuffle stages -> overall 4x upscaling.
        upsampling = []
        for _ in range(2):
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),  # 256 ch -> 64 ch, 2x spatial
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)
        # Output head; Tanh keeps pixel values in [-1, 1].
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        """Super-resolve *x*; spatial size grows 4x."""
        out1 = self.conv1(x)
        out2 = self.res_blocks(out1)
        out3 = self.conv2(out2)
        out4 = torch.add(out1, out3)  # long skip connection around the trunk
        out5 = self.upsampling(out4)
        return self.conv3(out5)
```
## Discriminator架構
```python=
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: outputs a 1-channel score map.

    Four downsampling stages (each halves H and W via a stride-2 conv), so
    the output patch is input size / 16 in each spatial dimension.

    Args:
        input_shape: (channels, height, width) of the high-resolution input.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Each of the 4 blocks below halves H and W once -> divide by 2**4.
        patch_h, patch_w = in_height // 2 ** 4, in_width // 2 ** 4
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            """One stage: stride-1 conv then stride-2 (downsampling) conv."""
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                # No BatchNorm right after the raw image input.
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        # Final 1-channel score map (no sigmoid: MSE adversarial loss is used).
        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the (N, 1, H/16, W/16) realism score map for *img*."""
        return self.model(img)
```
## FeatureExtractor架構
```python=
class FeatureExtractor(nn.Module):
    """Frozen VGG19 front-end used to compute the perceptual (content) loss.

    Keeps the first 18 layers of torchvision's pretrained VGG19 `features`
    module and runs images through them.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        backbone = vgg19(pretrained=True)
        selected = list(backbone.features.children())[:18]
        self.feature_extractor = nn.Sequential(*selected)

    def forward(self, img):
        """Return VGG19 feature maps for *img*."""
        return self.feature_extractor(img)
```
## 模型創建 & Loss設計
```python=
# High-resolution target size; channel count is prepended for the discriminator.
hr_shape = (128, 128)

generator = GeneratorResNet()
discriminator = Discriminator(input_shape=(3, *hr_shape))

# VGG feature extractor stays in eval mode (used only as a fixed loss network).
feature_extractor = FeatureExtractor()
feature_extractor.eval()

# Adversarial loss is MSE on the discriminator's patch scores;
# content loss is L1 between VGG features of generated and real images.
criterion_GAN = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()
```
## 檢查是否支援cuda
```python=
cuda = torch.cuda.is_available()
if cuda:
    # nn.Module.cuda() moves parameters/buffers in place and returns self,
    # so looping is equivalent to the rebind-per-line form.
    for module in (generator, discriminator, feature_extractor, criterion_GAN, criterion_content):
        module.cuda()
```
## Optimizer設定
```python=
# Adam with the usual GAN settings: lr=2e-4 and a reduced beta1 of 0.5.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)
```
## 資料讀取
```python=
# Standard ImageNet per-channel normalization constants, applied by the
# transforms below (and matching the pretrained VGG feature extractor's
# expected input statistics — presumably; confirm against torchvision docs).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
class ImageDataset(Dataset):
    """Paired LR/HR dataset: each image yields a bicubic-downscaled LR copy
    (1/4 of the HR size, matching the generator's 4x upscaling) and an HR copy.

    Args:
        root: directory containing the training images.
        hr_shape: (height, width) of the high-resolution target.
    """

    def __init__(self, root, hr_shape):
        hr_height, hr_width = hr_shape
        # BUG FIX: the original used hr_height for both Resize dimensions and
        # never used hr_width — harmless for square shapes, wrong otherwise.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        # convert('RGB') drops the alpha channel of 4-channel PNGs.
        img = Image.open(self.files[index % len(self.files)]).convert('RGB')
        return {"lr": self.lr_transform(img), "hr": self.hr_transform(img)}

    def __len__(self):
        return len(self.files)
```
## DataLoader
```python=
# NOTE: set num_workers=0 if multiprocessing misbehaves inside Jupyter.
train_dataset = ImageDataset("/train/celebA", hr_shape=hr_shape)
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
```
## Training
```python=
# torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch
# 0.4, so plain tensors are used below (requires_grad=False is the default).
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

n_epochs = 1000
for epoch in range(n_epochs):
    for i, imgs in enumerate(dataloader):
        # Model inputs; `.type(Tensor)` also moves them to GPU when available.
        imgs_lr = imgs["lr"].type(Tensor)
        imgs_hr = imgs["hr"].type(Tensor)

        # Adversarial ground truths, shaped like the discriminator's patch map.
        valid = Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape)))
        fake = Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape)))

        # ------------------
        #  Train Generator
        # ------------------
        optimizer_G.zero_grad()

        # Super-resolve the low-resolution batch.
        gen_hr = generator(imgs_lr)

        # Adversarial loss: push the discriminator toward predicting "real".
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

        # Content (perceptual) loss on VGG features; real features are
        # detached so they act as a fixed target.
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())

        # Total generator loss: content-dominated, small adversarial weight.
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Real images should score "valid", generated ones "fake"; gen_hr is
        # detached so no gradient flows back into the generator here.
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log progress
        # --------------
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % 50 == 0:
            # Save a side-by-side grid of 4x-interpolated inputs and SRGAN
            # outputs (F.interpolate defaults to nearest-neighbor mode).
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            img_grid = torch.cat((imgs_lr, gen_hr), -1)
            save_image(img_grid, "images/%d.png" % batches_done, normalize=False)

    if epoch % 100 == 0:
        # Checkpoint once per 100 epochs (saving per batch would just rewrite
        # the same files repeatedly — the original indentation was ambiguous).
        torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
        torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" % epoch)
```
---
# SRGAN測試
## 導入訓練完成模型
```python=
# Rebuild the generator and load the epoch-100 checkpoint; map_location lets a
# GPU-trained state dict load on a CPU-only machine.
generator = GeneratorResNet()
state = torch.load("saved_models/generator_100.pth", map_location=torch.device('cpu'))
generator.load_state_dict(state)
generator.eval()
```
## 讀入測試圖片
```python=
import PIL.Image as Image

# Same ImageNet normalization as training; no resize, so the generator
# upscales whatever size the test image has.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

img_path = "/test/000073.png"
test_img = Image.open(img_path).convert('RGB')
# Add a batch dimension and match the dtype/device expected by the model.
tr = Variable(transform(test_img).unsqueeze(0).type(Tensor))
```
## 送入Generator
```python=
# Inference only: disable gradient tracking to save memory and time.
with torch.no_grad():
    dst = generator(tr)

# Alternative: save the tensor directly as an image file.
# save_image(dst[0], "dst.png", normalize=True)

unloader = transforms.ToPILImage()

def tensor_to_PIL(tensor):
    """Convert one generated image tensor (batch of 1) to a PIL image.

    The generator's Tanh output lies in [-1, 1]; map it to [0, 1] and clamp.
    BUG FIX: the original scaling ((x + 1) * 128) / 255 can exceed 1.0
    (x = 1 gives 256/255), which overflows to 0 when ToPILImage converts to
    uint8, producing black speckles on the brightest pixels.
    """
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = ((image + 1.0) / 2.0).clamp(0.0, 1.0)
    return unloader(image)

# Convert to PIL for convenient viewing.
dst = tensor_to_PIL(dst)
```
## 原始圖片(放大4倍)

## SR生成圖片

###### tags: `SRGAN` `影像超解析` `pytorch`