# Visualization in Image Captioning
###### tags: `Deep Learning for Computer Vision`
We load a pre-trained **transformer** model and visualize its attention during image captioning.

[FULL TRANSFORMER NETWORK FOR IMAGE CAPTIONING](https://arxiv.org/pdf/2101.10804.pdf)
## Visualize Attention Map
The output feature of the last decoder layer is used to predict the next word via a linear layer whose output dimension equals the vocabulary size. We take the example image below to show the caption predicted by the model and to visualize the "words-to-patches" cross-attention weights in the decoder.
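A minimal NumPy sketch of this prediction step, under stated assumptions: the decoder output uses the $(L, N, E)$ layout described below, the linear head's weights `W` and `b` are random stand-ins for the trained layer, and `VOCAB_SIZE = 1000` is a hypothetical value. The next word is taken as the argmax of the logits at the current position (greedy decoding).

```python
import numpy as np

rng = np.random.default_rng(0)

L, N, E = 128, 1, 256   # caption positions, batch size, hidden size
VOCAB_SIZE = 1000       # hypothetical vocabulary size

# Hypothetical stand-ins for the trained linear head (weights W, bias b).
W = rng.standard_normal((E, VOCAB_SIZE)) * 0.02
b = np.zeros(VOCAB_SIZE)

def predict_next_token(decoder_out, position):
    """decoder_out: (L, N, E) output of the last decoder layer.
    Returns the greedily chosen token id for each batch item at `position`."""
    logits = decoder_out[position] @ W + b   # (N, VOCAB_SIZE)
    return logits.argmax(axis=-1)            # (N,)

decoder_out = rng.standard_normal((L, N, E))
next_ids = predict_next_token(decoder_out, position=0)
print(next_ids.shape)  # (1,)
```

In a full decoding loop, the predicted id would be appended to the caption and fed back as the decoder input for the next position.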
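According to the `torch.nn.MultiheadAttention` source code, the module returns `attn_output_weights` alongside the attention output (when `need_weights=True`, the default), so we can read these weights directly to visualize the attention map.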
https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py

https://github.com/pytorch/pytorch/blob/78b7a419b2b79152aef0fe91c726009acb0595af/torch/nn/functional.py#L4960
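The following NumPy sketch illustrates the core of that computation for the shapes used in this note; it is not PyTorch's actual code. It projects queries and keys, splits them into heads, applies a softmax to the scaled scores over the $S$ source positions, and averages over heads, which is what `torch.nn.MultiheadAttention` does by default when it returns weights (`average_attn_weights=True` in recent PyTorch versions). Biases, masks and dropout are omitted, the projection matrices are random stand-ins for trained weights, and the choice of 8 heads is an assumption.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention_weights(query, key, w_q, w_k, num_heads):
    """query: (L, N, E) decoder states; key: (S, N, E) image features.
    w_q, w_k: (E, E) projection matrices (biases omitted).
    Returns weights of shape (N, L, S), averaged over heads."""
    L, N, E = query.shape
    S = key.shape[0]
    head_dim = E // num_heads
    q = query @ w_q                                   # (L, N, E)
    k = key @ w_k                                     # (S, N, E)
    # Split the embedding into heads -> (N, H, L, d) and (N, H, S, d).
    q = q.reshape(L, N, num_heads, head_dim).transpose(1, 2, 0, 3)
    k = k.reshape(S, N, num_heads, head_dim).transpose(1, 2, 0, 3)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # (N, H, L, S)
    weights = softmax(scores, axis=-1)                # each row sums to 1 over S
    return weights.mean(axis=1)                       # average heads -> (N, L, S)

rng = np.random.default_rng(0)
E, H = 256, 8
query = rng.standard_normal((128, 1, E))
key = rng.standard_normal((361, 1, E))
w_q = rng.standard_normal((E, E)) / np.sqrt(E)
w_k = rng.standard_normal((E, E)) / np.sqrt(E)
attn = cross_attention_weights(query, key, w_q, w_k, num_heads=H)
print(attn.shape)  # (1, 128, 361)
```

Because each head's rows are probability distributions over the source positions, their average over heads is one as well, so every row of `attn` can be read as "where this query looks in the image".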

## Dimension
* Query embeddings of shape $(L, N, E_q)$, where $L$ is the target sequence length, $N$ is the batch size, and $E_q$ is the query embedding dimension.
* Key embeddings of shape $(S, N, E_k)$, where $S$ is the source sequence length, $N$ is the batch size, and $E_k$ is the key embedding dimension.
* Value embeddings of shape $(S, N, E_v)$, where $S$ is the source sequence length, $N$ is the batch size, and $E_v$ is the value embedding dimension.
* Attention outputs of shape $(L, N, E)$, where $L$ is the target sequence length, $N$ is the batch size, and $E$ is the embedding dimension.
* Attention output weights of shape $(N, L, S)$, where $N$ is the batch size, $L$ is the target sequence length, and $S$ is the source sequence length.
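For a single batch element and a single head, the shapes above fit together as in standard scaled dot-product attention. Here $h$ is the number of heads and $d_k = E/h$; the input and output projections are left out:

$$
A = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) \in \mathbb{R}^{L \times S},
\qquad
Q \in \mathbb{R}^{L \times d_k},\;
K, V \in \mathbb{R}^{S \times d_k},\;
AV \in \mathbb{R}^{L \times d_k}
$$

Each row of $A$ is a distribution over the $S$ source positions. Concatenating the $h$ head outputs $AV$ and applying the output projection gives the $(L, E)$ attention output; averaging $A$ over the heads gives the $(L, S)$ slice of the returned $(N, L, S)$ weights.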
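
In our case:
* query: (128, 1, 256)
* key: (361, 1, 256)
* value: (361, 1, 256)
* output: (128, 1, 256)
* weight: (1, 128, 361)

Here the target sequence length $L = 128$ is the maximum number of position embeddings (the maximum caption length), the source sequence length $S = 361$ is the number of flattened image feature positions, the batch size is $N = 1$, and the embedding dimension is $E = 256$.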
<img src="https://i.imgur.com/6zv0Yb2.png" width="450"/>
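To turn one row of the $(1, 128, 361)$ weights into an image overlay, the row has to be reshaped back into the spatial feature grid and upsampled. A minimal sketch, assuming the 361 source positions are a row-major flattening of a 19 × 19 grid (361 = 19 × 19) and using a hypothetical 299 × 299 image size; the random `attn` array is only a placeholder for the weights read from the model, and `word_heatmap` is a hypothetical helper name.

```python
import numpy as np

def word_heatmap(attn, token_idx, grid_hw, image_hw):
    """attn: (N, L, S) cross-attention weights; token_idx: decoder position.
    Reshapes one row of length S = h * w into an (h, w) grid, rescales it
    to [0, 1], and upsamples it (nearest neighbour) to image_hw."""
    h, w = grid_hw
    row = attn[0, token_idx]                        # (S,) weights over patches
    assert row.size == h * w, "S must equal h * w"
    grid = row.reshape(h, w)
    grid = (grid - grid.min()) / (grid.max() - grid.min() + 1e-8)
    H, W = image_hw
    # Map every output pixel to the grid cell it falls in.
    rows = np.arange(H) * h // H
    cols = np.arange(W) * w // W
    return grid[rows[:, None], cols[None, :]]       # (H, W)

attn = np.random.default_rng(0).random((1, 128, 361))  # placeholder weights
attn /= attn.sum(axis=-1, keepdims=True)
heat = word_heatmap(attn, token_idx=3, grid_hw=(19, 19), image_hw=(299, 299))
print(heat.shape)  # (299, 299)
```

Min-max rescaling per word makes each map use the full color range, which helps when comparing where words look but hides differences in absolute attention strength between words; bilinear upsampling would give smoother overlays than the nearest-neighbour version used here.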
## Result
Visualization of the attention weights computed by the “words-to-patches” cross-attention in the last decoder layer. “A young girl holding up a slice of pizza.” is the caption generated by our model. The attention maps for the words "girl" and "pizza" highlight the corresponding regions of the image.
Unlike image classification, where only the first row of the attention matrix $QK^\top$ (the query of the class token) needs to be visualized, image captioning requires the whole matrix: every caption position is a query, so each row gives the attention map of one word over the image patches.
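The contrast can be sketched with random placeholder weights. The ViT-style setting (a [CLS] token plus 14 × 14 = 196 patches) and the 19 × 19 captioning grid are assumptions for illustration. How rows map to words depends on the decoding setup: the hypothetical `OFFSET = 1` pairs the i-th word with the decoder position that holds it as input, assuming a start token at position 0; with `OFFSET = 0`, each word is instead paired with the row that predicted it.

```python
import numpy as np

rng = np.random.default_rng(0)

def normalize_rows(x):
    # Make every row a probability distribution, like softmax output.
    return x / x.sum(axis=-1, keepdims=True)

# Image classification (ViT-style self-attention over [CLS] + 14 x 14 patches):
# only the [CLS] query row matters; drop its [CLS] key column to keep the patches.
vit_attn = normalize_rows(rng.random((1, 197, 197)))    # (N, 1 + P, 1 + P)
cls_map = vit_attn[0, 0, 1:].reshape(14, 14)            # a single map

# Image captioning (decoder cross-attention): every caption position is a query,
# so every row of the (L, S) matrix yields its own map.
caption = ["a", "young", "girl", "holding", "up", "a", "slice", "of", "pizza"]
OFFSET = 1  # hypothetical: position 0 holds the start token
cap_attn = normalize_rows(rng.random((1, 128, 361)))    # (N, L, S)
word_maps = [
    (word, cap_attn[0, i + OFFSET].reshape(19, 19))     # one 19 x 19 map per word
    for i, word in enumerate(caption)
]
print(cls_map.shape, len(word_maps))  # (14, 14) 9
```

Only the first `len(caption) + OFFSET` of the 128 rows carry meaningful content; the remaining positions are padding up to the maximum caption length.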
