Visualization in Image Captioning

tags: Deep Learning for Computer Vision

We load a pre-trained transformer model to implement image-captioning visualization.
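
As a minimal loading sketch (the torch.hub repo and entry-point names below are placeholders, not the actual checkpoint used here):

```python
import torch

# Hypothetical hub entry: substitute the repo and model name of the
# pre-trained captioning transformer you actually use.
model = torch.hub.load('user/image-captioning-repo', 'caption_model',
                       pretrained=True)
model.eval()  # inference only: disable dropout, freeze norm statistics
```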


FULL TRANSFORMER NETWORK FOR IMAGE CAPTIONING

Visualize Attention Map

The output feature of the last decoder layer is used to predict the next word via a linear layer whose output dimension equals the vocabulary size. We take one example image below to show the caption predicted by the model and visualize the "words-to-patches" cross-attention weights in the decoder.
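
A minimal sketch of that prediction head; the hidden size 256 matches the shapes listed later, while the vocabulary size is an arbitrary placeholder:

```python
import torch
import torch.nn as nn

hidden_dim, vocab_size = 256, 30522   # vocab size is a placeholder

# Linear head: maps each decoder position's feature to vocabulary logits.
to_vocab = nn.Linear(hidden_dim, vocab_size)

decoder_out = torch.randn(128, 1, hidden_dim)  # (L, N, E) last-layer output
logits = to_vocab(decoder_out)                 # (L, N, vocab_size)

step = 4                                       # current decoding position
next_word_id = logits[step, 0].argmax(-1)      # greedy choice for that step
```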

According to the torch.nn.MultiheadAttention source code, we can directly read attn_output_weights to visualize the attention map:

https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
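
One practical way to capture attn_output_weights from a full model is a forward hook on the last decoder layer's cross-attention module. The module path, generate_caption, and image below are hypothetical placeholders; note also that newer PyTorch versions call this sub-module with need_weights=False inside nn.TransformerDecoderLayer, in which case the second output is None and the weights must be requested explicitly.

```python
import torch

attention_maps = []

def save_attn(module, inputs, outputs):
    # nn.MultiheadAttention returns (attn_output, attn_output_weights);
    # the weights have shape (N, L, S), averaged over the attention heads.
    attention_maps.append(outputs[1].detach())

# Hypothetical module path: adjust to your model's structure. This targets
# the "words-to-patches" cross attention of the last decoder layer.
cross_attn = model.transformer.decoder.layers[-1].multihead_attn
hook = cross_attn.register_forward_hook(save_attn)

caption = generate_caption(model, image)  # hypothetical inference helper
weights = attention_maps[-1]              # (1, 128, 361) in our case

hook.remove()  # clean up once the weights are captured
```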


https://github.com/pytorch/pytorch/blob/78b7a419b2b79152aef0fe91c726009acb0595af/torch/nn/functional.py#L4960


Dimensions

  • Query embeddings of shape \((L, N, E_q)\)
    • where \(L\) is the target sequence length, \(N\) is the batch size, and \(E_q\) is the query embedding dimension
  • Key embeddings of shape \((S, N, E_k)\)
    • where \(S\) is the source sequence length, \(N\) is the batch size, and \(E_k\) is the key embedding dimension
  • Value embeddings of shape \((S, N, E_v)\)
    • where \(S\) is the source sequence length, \(N\) is the batch size, and \(E_v\) is the value embedding dimension
  • Attention outputs of shape \((L, N, E)\)
    • where \(L\) is the target sequence length, \(N\) is the batch size, and \(E\) is the embedding dimension
  • Attention output weights of shape \((N, L, S)\)
    • where \(N\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length.

In our case:

  • query : (128, 1, 256)
  • key : (361, 1, 256)
  • value : (361, 1, 256)
  • output : (128, 1, 256)
  • weight : (1, 128, 361)

where the target sequence length \(L = 128\) is the maximum number of position embeddings (the maximum caption length), and the source sequence length \(S = 361\) comes from the image patches (a \(19 \times 19\) grid).
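
These shapes are easy to verify in isolation; the sketch below assumes 8 attention heads, which is arbitrary here:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=8)  # head count assumed

query = torch.randn(128, 1, 256)  # (L, N, E_q): one query per caption position
key   = torch.randn(361, 1, 256)  # (S, N, E_k): one key per image patch
value = torch.randn(361, 1, 256)  # (S, N, E_v)

# need_weights=True by default, so the averaged-over-heads attention
# weights come back alongside the attended output.
output, weights = mha(query, key, value)

print(output.shape)   # torch.Size([128, 1, 256]) -> (L, N, E)
print(weights.shape)  # torch.Size([1, 128, 361]) -> (N, L, S)
```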


Result

Visualization of the attention weights computed by the "words-to-patches" cross attention in the last decoder layer. "A young girl holding up a slice of pizza." is the caption generated by our model. We can see that the attention for both "girl" and "pizza" focuses on the corresponding image regions.

Unlike image classification, where visualization focuses only on the first row of the \(q\cdot k\) attention matrix (the class token's attention over the patches), image captioning requires the whole \(q\cdot k\) matrix: each generated word has its own row of attention over the patches, as sketched below.
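
A sketch of that per-word visualization, assuming the 361 patch weights form a \(19\times19\) grid; the random weights and hard-coded caption stand in for the tensor and tokens captured above:

```python
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

weights = torch.rand(1, 128, 361)  # placeholder for the captured (N, L, S) weights
words = ['a', 'young', 'girl', 'holding', 'up', 'a', 'slice', 'of', 'pizza']

fig, axes = plt.subplots(1, len(words), figsize=(2.5 * len(words), 3))
for t, (word, ax) in enumerate(zip(words, axes)):
    # Row t of the (L, S) matrix: how word t attends to every image patch.
    attn = weights[0, t].reshape(19, 19)
    # Upsample the 19x19 patch grid to image resolution (299x299 here)
    # so it can be overlaid on the input photo.
    attn = F.interpolate(attn[None, None], size=(299, 299),
                         mode='bilinear', align_corners=False)[0, 0]
    ax.imshow(attn, cmap='jet')
    ax.set_title(word)
    ax.axis('off')
plt.show()
```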
