---
title: Universe VAE Implementations
---

### Universe VAE Implementations

**Data**

Randomly cropped 1 x 128 x 128 (C, H, W) large-scale-structure image slices from 3D cubes. **We need to deal with the periodic boundary** when cropping; one way to do so is sketched below. (It is better to start from 1 x 128 x 128 first, and we can try high-resolution crops later.)

Image transformation: $\log(y^n_{m} + \epsilon)$, with $\epsilon = 10^{-8}$, where $n$ indexes the cube and $m$ indexes the slice.
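One way to handle the periodic boundary is to crop with wrapped indices, so a crop can run off one edge of the box and continue on the opposite side. Below is a minimal NumPy sketch of that idea combined with the log transform above; `random_periodic_crop` and its interface are illustrative assumptions, not the project's actual data loader.

```
import numpy as np

def random_periodic_crop(cube_slice, size=128, eps=1e-8, rng=None):
    # Illustrative sketch: random size x size crop from a 2D slice of a
    # periodic simulation box, followed by the log(y + eps) transform.
    rng = np.random.default_rng() if rng is None else rng
    h, w = cube_slice.shape
    top, left = rng.integers(h), rng.integers(w)
    # Wrapped indices let the crop cross the box edge instead of
    # discarding regions near the boundary.
    rows = np.arange(top, top + size) % h
    cols = np.arange(left, left + size) % w
    crop = cube_slice[np.ix_(rows, cols)]
    return np.log(crop + eps)[None]  # (1, size, size) = (C, H, W)
```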
**Visuals**

![](https://i.imgur.com/bhAWpWs.png)

Image pixel value distribution:

![](https://i.imgur.com/hWxUtzA.png)

**Model Forward Function**

```
def reparameterize(self, x):
    mu, logvar = x.chunk(2, dim=1)
    std = logvar.mul(0.5).exp_()
    eps = torch.randn_like(std)
    output = eps.mul(std).add_(mu)
    return output, mu, logvar

def forward(self, img_zs, img_h):
    img_z = img_zs.view(-1, 1, 128, 128)
    zero_vector = torch.zeros((img_z.shape[0], 512)).type_as(img_z)
    zs, _ = self.encoder(img_z, zero_vector)
    z = self.deep_set(zs)
    z, z_mu, z_logvar = self.reparameterize(z)
    h = self.decoder_z(z)
    h, h_mu, h_logvar = self.reparameterize(h)
    _, h_post = self.encoder(img_h, z)
    h_post, h_post_mu, h_post_logvar = self.reparameterize(h_post)
    recon_h = self.decoder_h(h_post)
    return recon_h, (z, z_mu, z_logvar), (h, h_mu, h_logvar), (h_post, h_post_mu, h_post_logvar)
```

**Implementation Details**

| # | Name      | Type     | Params |
|---|-----------|----------|--------|
| 0 | encoder   | ResNet   | 11 M   |
| 1 | deep_set  | DeepSet  | 279 K  |
| 2 | decoder_h | DecoderH | 11 M   |
| 3 | decoder_z | DecoderZ | 214 K   |

Note: regarding the sizes below, the `batch_size` is always the first dimension and is **ignored** when a size is mentioned.

**Step 1**

`encoder` encodes `img_z` into global representations `latent_zs` and feeds those latent features to `deep_set`.

```
zs, _ = self.encoder(img_z, zero_vector)  # output size 6 x 1024
```

- `img_z` of size (6, 1, 128, 128) (#slices, C, H, W) contains images drawn from the same cube; the number of slices defaults to 3, but I have set it as a hyperparameter to tune later if necessary.
- `latent_zs` is of size (6, 1024): the first dimension is the number of slices selected from a cube, and 1024 is the latent size, with 512 for $\mu_z$ and 512 for $\log{\Sigma_z}$.
- `encoder` has the typical structure of `ResNet18`.

**Step 2**

`deep_set` sums those `zs` up and generates `z`, of size (1024,).

```
z = self.deep_set(zs)                       # output size 1024
z, z_mu, z_logvar = self.reparameterize(z)  # shapes 512, 512, 512; z is the sample
```

**Step 3**

`h` is conditioned on `z`.

```
h = self.decoder_z(z)                       # input shape 512, output shape 1024
h, h_mu, h_logvar = self.reparameterize(h)  # shapes 512, 512, 512; h is the sample
```

**Step 4**

`encoder` encodes a local/sliced image `img_h` into a representation `h_post`, conditioned on `z`.

```
_, h_post = self.encoder(img_h, z)
h_post, h_post_mu, h_post_logvar = self.reparameterize(h_post)
```

`img_h` is of size (1, 1, 128, 128) and `h_post` is of shape 512.

**Step 5**

The decoder decodes `h_post` into reconstructed images, defined as `recon_h`.

```
recon_h = self.decoder_h(h_post)  # output size 1 x 128 x 128
```

**Full Steps**

```
def reparameterize(self, x):
    # Split into (mu, logvar) and sample with the reparameterization trick.
    mu, logvar = x.chunk(2, dim=1)
    std = logvar.mul(0.5).exp_()
    eps = torch.randn_like(std)
    output = eps.mul(std).add_(mu)
    return output, mu, logvar

def forward(self, img_zs, img_h):
    img_z = img_zs.view(-1, 1, 128, 128)
    zero_vector = torch.zeros((img_z.shape[0], 512)).type_as(img_z)
    # Step 1: encode the global slices (with a zero conditioning vector).
    zs, _ = self.encoder(img_z, zero_vector)
    # Step 2: aggregate the slice embeddings into the global latent z.
    z = self.deep_set(zs)
    z, z_mu, z_logvar = self.reparameterize(z)
    # Step 3: prior over the local latent h, conditioned on z.
    h = self.decoder_z(z)
    h, h_mu, h_logvar = self.reparameterize(h)
    # Step 4: posterior over h from the local image, conditioned on z.
    _, h_post = self.encoder(img_h, z)
    h_post, h_post_mu, h_post_logvar = self.reparameterize(h_post)
    # Step 5: reconstruct the local image from the posterior sample.
    recon_h = self.decoder_h(h_post)
    return recon_h, (z, z_mu, z_logvar), (h, h_mu, h_logvar), (h_post, h_post_mu, h_post_logvar)
```

**Loss Function**

The loss is an L2 (MSE) reconstruction term plus two KL terms:

- KL of $z$ against $N(0, 1)$
- KL of $h_{post}$ against $N(\mu_h, \Sigma_h)$

```
recon_h, (z, z_mu, z_logvar), (h, h_mu, h_logvar), (h_post, h_post_mu, h_post_logvar) = self(img_zs, img_h)
l2 = self.mse_loss(recon_h, img_h)
kl_z = self.kl_loss(None, None, z_mu, z_logvar, kl_z=True)
kl_h = self.kl_loss(h_mu, h_logvar, h_post_mu, h_post_logvar, kl_z=False)
kl_loss = self.hparams.kl_z * kl_z + self.hparams.kl_h * kl_h
loss = l2 + kl_loss
```

Defined below:

```
def mse_loss(self, recon_h, img_h):
    # Sum over pixels, then average over the batch.
    return F.mse_loss(recon_h, img_h, reduction='sum') / recon_h.size()[0]

def kl_loss(self, q_mu, q_logvar, p_mu, p_logvar, kl_z=False):
    # Naming note: here p is the (approximate) posterior and q the prior,
    # so kl_divergence(p, q) below is KL(posterior || prior).
    p = torch.distributions.normal.Normal(p_mu, p_logvar.mul(0.5).exp_())
    if kl_z:
        q = torch.distributions.normal.Normal(0., 1.)
    else:
        q = torch.distributions.normal.Normal(q_mu, q_logvar.mul(0.5).exp_())
    return torch.distributions.kl.kl_divergence(p, q).mean()
```
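Both KL terms have closed forms for diagonal Gaussians. The sketch below is a reference implementation of what `kl_loss` computes, using the tensors returned by `forward`; it is a sketch for clarity, not the model's own code. Note also that `mse_loss` sums over pixels while `kl_loss` averages over latent dimensions, a scale difference the `kl_z`/`kl_h` weights have to absorb.

```
import torch

def gaussian_kl(mu1, logvar1, mu2, logvar2):
    # Elementwise KL( N(mu1, exp(logvar1)) || N(mu2, exp(logvar2)) ).
    return 0.5 * (logvar2 - logvar1
                  + (logvar1.exp() + (mu1 - mu2).pow(2)) / logvar2.exp()
                  - 1.0)

# KL(q(z) || N(0, I)): matches kl_loss(None, None, z_mu, z_logvar, kl_z=True)
kl_z = gaussian_kl(z_mu, z_logvar,
                   torch.zeros_like(z_mu), torch.zeros_like(z_logvar)).mean()

# KL(h_post || h prior): matches kl_loss(h_mu, h_logvar, h_post_mu, h_post_logvar)
kl_h = gaussian_kl(h_post_mu, h_post_logvar, h_mu, h_logvar).mean()
```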