---
title: Universe VAE Implementations
---
### Universe VAE Implementations
**Data**
Randomly cropped 1 x 128 x 128 (C, H, W) patches (**need to deal with the periodic boundary**) taken from large-scale-structure image slices of 3D cubes.
(It's better to start from 1 x 128 x 128 first, and we can try higher-resolution ones later.)
Image transformation: $\log(y^n_{m} + \epsilon)$ with $\epsilon = 10^{-8}$, where $n$ indexes the cube and $m$ the slice.
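A minimal sketch of the cropping and transform, assuming each slice arrives as a 2D NumPy array (the function names and the `np.roll`-based boundary handling are my assumptions, not the project's code):
```
import numpy as np

EPS = 1e-8

def random_periodic_crop(slice_2d, size=128):
    # roll the slice by a random offset so that crops which cross the cube
    # boundary wrap around instead of falling off the edge
    h, w = slice_2d.shape
    dy, dx = np.random.randint(h), np.random.randint(w)
    rolled = np.roll(slice_2d, shift=(dy, dx), axis=(0, 1))
    return rolled[:size, :size]

def transform(slice_2d):
    # log(y + eps) with eps = 1e-8, as described above; add a channel
    # dimension to get the (1, 128, 128) shape the model expects
    return np.log(random_periodic_crop(slice_2d) + EPS)[None]
```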
**Visuals**

*(Figure: image pixel value distribution of the transformed slices)*
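A hedged sketch for reproducing this plot (the data below is a random placeholder standing in for the real transformed slices):
```
import numpy as np
import matplotlib.pyplot as plt

# placeholder stand-in for the transformed slices, shape (N, 1, 128, 128)
imgs = np.log(np.random.lognormal(size=(16, 1, 128, 128)) + 1e-8)

plt.hist(imgs.ravel(), bins=200)
plt.xlabel('log-transformed pixel value')
plt.ylabel('count')
plt.title('Image pixel value distribution')
plt.show()
```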

**Model Forward Function**
```
def reparameterize(self, x):
    # split the encoder output into mu and logvar along dim=1, then draw a
    # sample with the reparameterization trick: output = mu + std * eps
    mu, logvar = x.chunk(2, dim=1)
    std = logvar.mul(0.5).exp_()
    eps = torch.randn_like(std)
    output = eps.mul(std).add_(mu)
    return output, mu, logvar

def forward(self, img_zs, img_h):
    # fold the slice dimension into the batch: (B, n, 1, 128, 128) -> (B*n, 1, 128, 128)
    img_z = img_zs.view(-1, 1, 128, 128)
    # encode the z-path slices with an all-zero conditioning vector
    zero_vector = torch.zeros((img_z.shape[0], 512)).type_as(img_z)
    zs, _ = self.encoder(img_z, zero_vector)
    # aggregate the per-slice latents into one global latent and sample z
    z = self.deep_set(zs)
    z, z_mu, z_logvar = self.reparameterize(z)
    # conditional prior over h, predicted from z
    h = self.decoder_z(z)
    h, h_mu, h_logvar = self.reparameterize(h)
    # posterior over h from the target slice, conditioned on z
    _, h_post = self.encoder(img_h, z)
    h_post, h_post_mu, h_post_logvar = self.reparameterize(h_post)
    # reconstruct the target slice from the posterior sample
    recon_h = self.decoder_h(h_post)
    return recon_h, (z, z_mu, z_logvar), (h, h_mu, h_logvar), (h_post, h_post_mu, h_post_logvar)
```
**Implementation Details**
| # | Name      | Type     | Params |
|---|-----------|----------|--------|
| 0 | encoder   | ResNet   | 11 M   |
| 1 | deep_set  | DeepSet  | 279 K  |
| 2 | decoder_h | DecoderH | 11 M   |
| 3 | decoder_z | DecoderZ | 214 K  |
Note: in all sizes below, the `batch_size` is always the first dimension and is **omitted** whenever a size is quoted.
**Step 1** `encoder` encodes `img_z` into per-slice latent representations `latent_zs` and feeds them to `deep_set`.
```
zs, _ = self.encoder(img_z, zero_vector) # output_size 6 x 1024
```
- `img_z` has size (6, 1, 128, 128) (#slices, C, H, W); the slices are all drawn from the same cube. The number of slices per cube is a hyperparameter to tune later if necessary (3 by default; the sizes quoted here assume 6).
- `latent_zs` has size (6, 1024): 6 is the number of slices selected from a cube, and 1024 is the latent size, split into 512 for $\mu_z$ and 512 for $\log{\Sigma_z}$.
- `encoder` is the standard `ResNet18` architecture, modified to take a conditioning vector as a second input (see the sketch after this list).
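The conditioning mechanism is not spelled out above; a plausible sketch, inferred from `self.encoder(img, cond)` returning two outputs, is to concatenate the condition with the pooled ResNet features and project through two linear heads (all of this is an assumption, not the actual implementation):
```
import torch
import torch.nn as nn
import torchvision

class CondResNetEncoder(nn.Module):
    def __init__(self, latent_dim=512):
        super().__init__()
        backbone = torchvision.models.resnet18()
        # adapt the stem to single-channel input
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2,
                                   padding=3, bias=False)
        backbone.fc = nn.Identity()  # keep the 512-d pooled features
        self.backbone = backbone
        # two heads mirror the (zs, h_post) pair the forward function unpacks;
        # each outputs 2 * latent_dim = mu and logvar concatenated
        self.head_z = nn.Linear(512 + latent_dim, 2 * latent_dim)
        self.head_h = nn.Linear(512 + latent_dim, 2 * latent_dim)

    def forward(self, img, cond):
        feats = torch.cat([self.backbone(img), cond], dim=1)
        return self.head_z(feats), self.head_h(feats)
```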
**Step 2** `deep_set` sums the `zs` up and generates `z` of size (1024,).
```
z = self.deep_set(zs) # output size 1024
z, z_mu, z_logvar = self.reparameterize(z) # shape 512, 512, 512, and z is the sample
```
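`DeepSet` is presumably the usual sum-pooling construction $\rho\big(\sum_i \phi(z_i)\big)$; a minimal sketch (the hidden width, exact layers, and the internal reshape are assumptions and will not reproduce the 279 K parameter count):
```
import torch.nn as nn

class DeepSet(nn.Module):
    def __init__(self, in_dim=1024, hidden=256, out_dim=1024, n_slices=6):
        super().__init__()
        self.n_slices = n_slices
        self.phi = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.rho = nn.Linear(hidden, out_dim)

    def forward(self, zs):
        # zs arrives flattened as (batch * n_slices, in_dim); regroup, then
        # sum over the slice axis, which makes the output permutation-invariant
        zs = zs.view(-1, self.n_slices, zs.size(1))
        return self.rho(self.phi(zs).sum(dim=1))
```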
**Step 3** The prior over `h` is conditioned on `z`.
```
h = self.decoder_z(z)  # input shape 512, output shape 1024
h, h_mu, h_logvar = self.reparameterize(h)  # shapes 512, 512, 512, and h is the sample
```
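`decoder_z` is then just a small MLP from the 512-d sample $z$ to the 1024-d prior parameters ($\mu_h$ and $\log{\Sigma_h}$); a stand-in with an assumed hidden width:
```
import torch.nn as nn

# stand-in for DecoderZ: z (512) -> prior parameters of h (1024 = mu + logvar)
decoder_z = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 1024),
)
```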
**Step 4** `encoder` encodes a local/sliced image `img_h` to representation `h_post` conditioned on `z`.
```
_, h_post = self.encoder(img_h, z)
h_post, h_post_mu, h_post_logvar = self.reparameterize(h_post)
```
`img_h` has size (1, 1, 128, 128), and `h_post` has shape 512 after the reparameterization split (the raw encoder output is 1024).
**Step 5** `decoder_h` decodes `h_post` into the reconstructed image, defined as `recon_h`.
```
recon_h = self.decoder_h(h_post) # output size 1 x 128 x 128
```
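A plausible shape for `DecoderH` is a linear projection followed by a stack of transposed convolutions that upsample from 4 x 4 to 128 x 128 (the layer layout and channel counts are assumptions):
```
import torch.nn as nn

class DecoderH(nn.Module):
    def __init__(self, latent_dim=512):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
        blocks = []
        channels = [256, 128, 64, 32, 16]
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            # each transposed conv doubles the spatial resolution
            blocks += [nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                       nn.ReLU()]
        # final upsample to 128 x 128 and projection to a single channel
        blocks += [nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1)]
        self.net = nn.Sequential(*blocks)

    def forward(self, h_post):
        x = self.fc(h_post).view(-1, 256, 4, 4)
        return self.net(x)  # (B, 1, 128, 128)
```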
**Loss Function**
Reconstruction: L2 (MSE) loss.
KL terms (closed form below):
- KL between the posterior of $z$ and $\mathcal{N}(0, 1)$
- KL between the posterior $h_{post}$ and the conditional prior $\mathcal{N}(\mu_h, \Sigma_h)$ produced by `decoder_z`
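For reference, the closed-form KL between two diagonal Gaussians that `kl_divergence` evaluates below is, per dimension,

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1, \sigma_1^2)\,\|\,\mathcal{N}(\mu_2, \sigma_2^2)\big) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2},
$$

which for the $z$ term (prior $\mathcal{N}(0, 1)$) reduces to the familiar $-\frac{1}{2}\left(1 + \log\sigma_1^2 - \mu_1^2 - \sigma_1^2\right)$.

In the training step: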
```
recon_h, (z, z_mu, z_logvar), (h, h_mu, h_logvar), (h_post, h_post_mu, h_post_logvar) = self(img_zs, img_h)
l2 = self.mse_loss(recon_h, img_h)
kl_z = self.kl_loss(None, None, z_mu, z_logvar, kl_z=True)                  # KL(q(z) || N(0, 1))
kl_h = self.kl_loss(h_mu, h_logvar, h_post_mu, h_post_logvar, kl_z=False)  # KL(q(h|x, z) || p(h|z))
kl_loss = self.hparams.kl_z * kl_z + self.hparams.kl_h * kl_h              # hyperparameter weights
loss = l2 + kl_loss
```
Defined below:
```
def mse_loss(self, recon_h, img_h):
    # sum the squared error over pixels, then average over the batch
    return F.mse_loss(recon_h, img_h, reduction='sum') / recon_h.size(0)

def kl_loss(self, q_mu, q_logvar, p_mu, p_logvar, kl_z=False):
    # p holds the *posterior* parameters and q the *prior*, so
    # kl_divergence(p, q) below computes KL(posterior || prior)
    p = torch.distributions.normal.Normal(p_mu, p_logvar.mul(0.5).exp_())
    if kl_z:
        q = torch.distributions.normal.Normal(0., 1.)
    else:
        q = torch.distributions.normal.Normal(q_mu, q_logvar.mul(0.5).exp_())
    return torch.distributions.kl.kl_divergence(p, q).mean()
```
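A quick sanity check of the argument order: `torch.distributions.kl.kl_divergence(p, q)` computes $\mathrm{KL}(p \,\|\, q)$, so passing the posterior parameters as `p` matches the usual VAE objective (the values below are made up for illustration):
```
import torch

post = torch.distributions.Normal(torch.zeros(512), torch.ones(512))       # stand-in posterior
prior = torch.distributions.Normal(torch.zeros(512), 2 * torch.ones(512))  # stand-in prior
print(torch.distributions.kl.kl_divergence(post, prior).mean())  # KL(post || prior)
```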