:::success
### Finished architecture:
```python
from typing import List

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class VectorQuantizer(nn.Module):
    """
    Quantizer block, inspired by:
    https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
    https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
    """
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 beta: float = 0.25):
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.beta = beta
        self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
    def forward(self, latents: Tensor) -> Tensor:
        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
        latents_shape = latents.shape
        flat_latents = latents.view(-1, self.D)  # [BHW x D]

        # Compute L2 distance between latents and embedding weights:
        # ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e
        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
               torch.sum(self.embedding.weight ** 2, dim=1) - \
               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]

        # Get the encoding that has the min distance
        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]

        # Convert to one-hot encodings
        device = latents.device
        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]

        # Quantize the latents
        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]

        # Compute the VQ losses
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())
        vq_loss = commitment_loss * self.beta + embedding_loss

        # Straight-through estimator: pass decoder gradients to the encoder unchanged
        quantized_latents = latents + (quantized_latents - latents).detach()

        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]
class ResidualLayer(nn.Module):
    # Note: the skip connection assumes in_channels == out_channels.
    def __init__(self,
                 in_channels: int,
                 out_channels: int):
        super(ResidualLayer, self).__init__()
        self.resblock = nn.Sequential(nn.Conv2d(in_channels, out_channels,
                                                kernel_size=3, padding=1, bias=False),
                                      nn.ReLU(True),
                                      nn.Conv2d(out_channels, out_channels,
                                                kernel_size=1, bias=False))

    def forward(self, input: Tensor) -> Tensor:
        return input + self.resblock(input)
class VQVAE(nn.Module):
    def __init__(self,
                 in_channels: int,
                 embedding_dim: int,
                 num_embeddings: int,
                 hidden_dims: List = None,
                 beta: float = 0.25,
                 img_size: int = 64,
                 **kwargs) -> None:
        super(VQVAE, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.img_size = img_size
        self.beta = beta

        modules = []
        if hidden_dims is None:
            hidden_dims = [128, 256]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels,
                          kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU())
        )

        for _ in range(6):
            modules.append(ResidualLayer(in_channels, in_channels))
        modules.append(nn.LeakyReLU())

        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, embedding_dim,
                          kernel_size=1, stride=1),
                nn.LeakyReLU())
        )

        self.encoder = nn.Sequential(*modules)

        self.vq_layer = VectorQuantizer(num_embeddings,
                                        embedding_dim,
                                        self.beta)

        # Build Decoder
        modules = []
        modules.append(
            nn.Sequential(
                nn.Conv2d(embedding_dim,
                          hidden_dims[-1],
                          kernel_size=3,
                          stride=1,
                          padding=1),
                nn.LeakyReLU())
        )

        for _ in range(6):
            modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))
        modules.append(nn.LeakyReLU())

        hidden_dims.reverse()
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=4,
                                       stride=2,
                                       padding=1),
                    nn.LeakyReLU())
            )

        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_dims[-1],
                                   out_channels=4,
                                   kernel_size=4,
                                   stride=2, padding=1),
                nn.Tanh()))

        self.decoder = nn.Sequential(*modules)
    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing it through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        return [result]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) [B x D x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder(z)
        return result

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        encoding = self.encode(input)[0]
        quantized_inputs, vq_loss = self.vq_layer(encoding)
        return [self.decode(quantized_inputs), input, vq_loss]
    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Combines the reconstruction loss (BCE with logits, since the
        targets are binary images) with the VQ loss.
        """
        recons = args[0]
        input = args[1]
        vq_loss = args[2]

        recons_loss = F.binary_cross_entropy_with_logits(input=recons, target=input)

        loss = recons_loss + vq_loss
        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'VQ_Loss': vq_loss}

    # def sample(self,
    #            num_samples: int,
    #            current_device: Union[int, str], **kwargs) -> Tensor:
    #     raise Warning('VQVAE sampler is not implemented.')
    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image.
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        # The decoder output is treated as logits by the BCE loss, so
        # binarize at probability 0.5 (i.e. logit 0).
        return torch.sigmoid(self.forward(x)[0]) > 0.5  # since we are dealing with binary images
```
:::
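As a quick sanity check, here is a minimal usage sketch of the finished architecture. The channel count, image size, and codebook settings are illustrative assumptions, not tuned values; with the default `hidden_dims = [128, 256]`, a 64×64 input is downsampled to a 16×16 grid of codes.

```python
import torch

# Illustrative settings: 4-channel 64x64 binary images, codebook of 512 vectors of dim 64.
model = VQVAE(in_channels=4, embedding_dim=64, num_embeddings=512)

x = torch.rand(8, 4, 64, 64)      # dummy batch [B x C x H x W]
recons, inp, vq_loss = model(x)   # forward returns [reconstruction, input, vq_loss]
print(recons.shape)               # torch.Size([8, 4, 64, 64])

losses = model.loss_function(recons, inp, vq_loss)
print({k: float(v) for k, v in losses.items()})
```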
:::warning
## Working on:
- coding the training loop: which learning rate and optimizer? is a BCE loss around 1 OK? (see the sketch below)
- fixing remote VS Code (OpenVPN doesn't work for me)
:::
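On the loss question: `F.binary_cross_entropy_with_logits` averaged over pixels sits near ln 2 ≈ 0.69 for a model whose logits are close to zero, so a total loss around 1 early in training is not alarming by itself; what matters is that it decreases. Below is a minimal training-loop sketch. Adam with `lr=1e-3` is an assumed starting point to tune, not a setting taken from the referenced repositories, and `train_loader` is a hypothetical `DataLoader` of binary image batches.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Assumed hyper-parameters -- starting points to tune, not final values.
model = VQVAE(in_channels=4, embedding_dim=64, num_embeddings=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(50):
    for batch in train_loader:  # `train_loader`: hypothetical DataLoader of binary images
        batch = batch.to(device)
        optimizer.zero_grad()
        recons, inp, vq_loss = model(batch)
        losses = model.loss_function(recons, inp, vq_loss)  # BCE-with-logits + VQ loss
        losses['loss'].backward()
        optimizer.step()
    print(f"epoch {epoch}: loss = {losses['loss'].item():.4f}")
```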