# Notes on "[Attention Transfer Knowledge Distillation](https://arxiv.org/abs/1612.03928)"
###### tags: `notes` `knowledge-distillation` `attention`
Author: [Akshay Kulkarni](https://akshayk07.weebly.com/)
Note: Actual title of paper is "**Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer**" and it was published as a ICLR '17 poster.
## Brief Outline
By properly defining attention for CNNs, it can be used to significantly improve the performance of a student CNN network by forcing it to mimic the attention maps of a powerful teacher network.
## Introduction
* Artificial attention lets a system “attend” to an object to examine it with greater detail.
* They seek to answer the following question: can a teacher network improve the performance of a student network by providing it with information about where it looks, i.e., where it concentrates its attention?
* They consider attention as a set of spatial maps that essentially try to encode on which spatial areas of the input the network focuses most for taking its output decision.
* Further, these maps can be defined w.r.t. various layers of the network to capture low, mid, and high-level representation information.
## Methodology
### Activation-based Attention Transfer
* Consider a CNN layer and its corresponding activation tensor $A \in \mathbb{R}^{C \times H \times W}$. It has $C$ feature planes with spatial dimensions $H \times W$.
* An activation-based mapping function $\mathcal{F}$ (w.r.t. that layer) takes as input the 3D tensor $A$ and outputs a spatial attention map i.e. a flattened 2D tensor defined over the spatial dimensions.
$$
\mathcal{F}:\mathbb{R}^{C \times H \times W} \rightarrow \mathbb{R}^{H \times W}
\tag{1}
$$
* They make an assumption that the absolute value of a hidden neuron activation can be used as an indication of the importance of that neuron w.r.t. the specific input.
* Thus, they consider the following spatial maps in this work:
	* Sum of absolute values (Eq. 2): $\mathcal{F}_{\text{sum}}(A)=\sum_{i=1}^C |A_i|$
	* Sum of absolute values raised to the power of $p$, where $p > 1$ (Eq. 3): $\mathcal{F}_{\text{sum}}^p(A)=\sum_{i=1}^C |A_i|^p$
	* Maximum of absolute values raised to the power of $p$, where $p > 1$ (Eq. 4): $\mathcal{F}_\max^p(A)=\max_{i=1,\dots,C} |A_i|^p$
* Here, $A_i = A(i, :, :)$ (using MATLAB notation) and max, power and absolute value operations are elementwise.
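A minimal PyTorch sketch of these three mapping functions (our own illustrative code, not the authors' release; the helper name `attention_map` and the toy shapes are assumptions):

```python
import torch

def attention_map(A, mode="sum_p", p=2):
    """Collapse an activation tensor A (N x C x H x W) into a spatial
    attention map over H x W, following Eqs. 2-4."""
    absA = A.abs()
    if mode == "sum":      # F_sum(A), Eq. 2
        return absA.sum(dim=1)
    if mode == "sum_p":    # F_sum^p(A), Eq. 3
        return absA.pow(p).sum(dim=1)
    if mode == "max_p":    # F_max^p(A), Eq. 4
        return absA.pow(p).amax(dim=1)
    raise ValueError(f"unknown mode: {mode}")

# Toy example: a batch of 8 activations with 64 channels on a 32x32 grid
A = torch.randn(8, 64, 32, 32)
print(attention_map(A, "sum_p", p=2).shape)  # torch.Size([8, 32, 32])
```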
#### Visualizing Attention Maps
* From the visualizations, it is found that attention maps focus on different parts for different layers in the network.
* In the first layers, neuron activation levels are high at low-level gradient points; in the middle layers, activation is higher at the most discriminative regions such as eyes or wheels; and in the top layers, the maps reflect full objects.
![Visualization of Attention Maps](https://i.imgur.com/RxJk8P6.png)
* The image above shows $\mathcal{F}_{\text{sum}}$ attention maps at different levels of a trained face recognition network: mid-level attention maps have higher activation around the eyes, nose, and lips, while high-level activations cover the whole face.
* It is concluded that most discriminative regions have higher activation levels and that shape details disappear as the parameter $p$ increases.
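To reproduce such visualizations, the attention map only needs to be upsampled to the input resolution and rescaled for display; below is a small sketch under the same assumptions as the previous snippet (the helper `displayable_map` is ours):

```python
import torch
import torch.nn.functional as F

def displayable_map(A, input_hw, p=2):
    """Upsample an F_sum^p attention map to the input resolution and
    rescale it to [0, 1] so it can be overlaid on the input image."""
    m = A.abs().pow(p).sum(dim=1, keepdim=True)                     # N x 1 x H x W
    m = F.interpolate(m, size=input_hw, mode="bilinear", align_corners=False)
    mn = m.amin(dim=(-2, -1), keepdim=True)
    mx = m.amax(dim=(-2, -1), keepdim=True)
    return ((m - mn) / (mx - mn + 1e-8)).squeeze(1)                 # N x H_in x W_in
```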
#### Attention Mapping Function Properties
* Compared to $\mathcal{F}_{\text{sum}}(A)$, the spatial map $\mathcal{F}_{\text{sum}}^p(A)$ (where $p>1$) puts more weight to spatial locations that correspond to the neurons with the highest activations, i.e., puts more weight to the most discriminative parts (the larger the $p$ the more focus is placed on those parts with highest activations).
* Furthermore, among all neuron activations corresponding to the same spatial location, $\mathcal{F}_\max^p(A)$ will consider only one of them to assign a weight to that spatial location (as opposed to $\mathcal{F}_{\text{sum}}^p(A)$ that will favor spatial locations that carry multiple neurons with high activations).
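A toy numeric check of this difference (our own example, not from the paper): one location with several moderately high channel activations versus one with a single dominant activation.

```python
import torch

# Channel activations at two spatial locations (3 channels each):
# location A has several moderately high activations, B has one dominant one.
a = torch.tensor([2.0, 2.0, 2.0])
b = torch.tensor([3.0, 0.1, 0.1])

p = 2
print(a.abs().pow(p).sum().item(), b.abs().pow(p).sum().item())  # 12.0 vs ~9.02 -> F_sum^2 favors A
print(a.abs().pow(p).max().item(), b.abs().pow(p).max().item())  # 4.0  vs  9.0  -> F_max^2 favors B
```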
#### Training Procedure
* In attention transfer, given the spatial attention maps of a teacher network, the goal is to train a student network that will not only make correct predictions but will also have attention maps that are similar to those of the teacher.
* They assume that transfer losses are placed between student and teacher attention maps of the same shape; if needed, attention maps can be interpolated to match shapes.
* Let $S, T$ and $W_S, W_T$ denote the student, teacher and their weights respectively and let $\mathcal{L}(W, x)$ denote the standard cross entropy loss.
* Let $\mathcal{I}$ denote the indices of all teacher-student activation layer pairs for which we want to transfer attention maps. Then we can define the following total loss:
$$
\mathcal{L}_{AT} = \mathcal{L}(W_S, x) + \frac{\beta}{2}\sum_{j \in \mathcal{I}}\left\|\frac{Q_S^j}{\|Q_S^j\|_2} - \frac{Q_T^j}{\|Q_T^j\|_2}\right\|_p
\tag{5}
$$
* Here, $Q_S^j = \text{vec}(\mathcal{F}(A_S^j))$ and $Q_T^j = \text{vec}(\mathcal{F}(A_T^j))$ are respectively the $j$-th pair of student and teacher attention maps in vectorized form, and $p$ refers to norm type (in the experiments, they use $p$=2).
* They use $l_2$-normalized attention maps i.e. $\frac{Q}{||Q||_2}$ instead of just $Q$. They emphasize that normalizing the attention map is important for the success of this approach.
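A compact sketch of Eq. 5 in PyTorch (illustrative only; the helper names, the use of $\mathcal{F}_{\text{sum}}^p$, the batch-mean reduction, and the $\beta$ default are our choices, not necessarily those of the released code):

```python
import torch
import torch.nn.functional as F

def at(A, p=2):
    """Vectorized, l2-normalized attention map Q / ||Q||_2 built from F_sum^p."""
    q = A.abs().pow(p).sum(dim=1).flatten(start_dim=1)  # N x (H*W)
    return F.normalize(q, p=2, dim=1)

def at_total_loss(ce_loss, student_acts, teacher_acts, beta=1e3, p=2):
    """Eq. 5: student cross-entropy plus (beta/2) * sum over layer pairs of the
    l2 distance between normalized student and teacher attention maps."""
    transfer = 0.0
    for a_s, a_t in zip(student_acts, teacher_acts):
        transfer = transfer + (at(a_s, p) - at(a_t, p)).pow(2).sum(dim=1).sqrt().mean()
    return ce_loss + beta / 2 * transfer
```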
### Gradient-based Attention Transfer
* They define attention as gradient w.r.t. input which can be viewed as an input sensitivity map i.e. attention at an input spatial location encodes how sensitive the output prediction is w.r.t. changes at that input location.
* Define the gradient of the loss w.r.t. input for teacher and student as
$$
J_S = \frac{\partial}{\partial x}\mathcal{L}(W_S, x) \hspace{10pt} \text{and} \hspace{10pt}
J_T = \frac{\partial}{\partial x}\mathcal{L}(W_T, x)
\tag{6}
$$
* To make student gradient attention similar to the teacher, they minimize a distance between them (they use $l_2$, but other distances can also be used):
$$
\mathcal{L}_{AT}(W_S, W_T, x) = \mathcal{L}(W_S, x) + \frac{\beta}{2}||J_S-J_T||_2
\tag{7}
$$
* Since $W_T$ and $x$ are fixed, the needed derivative w.r.t. $W_S$ is obtained by differentiating Eq. 7 (using the definitions in Eq. 6 and the chain rule):
$$
\frac{\partial}{\partial W_S}\mathcal{L}_{AT} = \frac{\partial}{\partial W_S}\mathcal{L}(W_S, x) + \beta(J_S-J_T)\frac{\partial^2}{\partial W_S \partial x}\mathcal{L}(W_S, x)
\tag{8}
$$
* To do the weight update, a forward pass and backpropagation are first done to obtain $J_S$ and $J_T$. Then, the second error $\frac{\beta}{2}||J_S-J_T||_2$ is computed and propagated backward a second time.
* This can be implemented efficiently in a framework supporting automatic differentiation.
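A rough sketch of one such training step (our own illustrative code, assuming `student`, `teacher`, and `optimizer` already exist; the squared $l_2$ distance is used here for simplicity):

```python
import torch
import torch.nn.functional as F

def grad_at_step(student, teacher, x, y, optimizer, beta=1e3):
    """One step of gradient-based attention transfer (Eqs. 6-7) using a
    second backward pass through the student (double backpropagation)."""
    x = x.clone().requires_grad_(True)

    # Teacher input-gradient J_T: no higher-order graph is needed
    t_loss = F.cross_entropy(teacher(x), y)
    j_t = torch.autograd.grad(t_loss, x)[0].detach()

    # Student loss and its input-gradient J_S, kept differentiable w.r.t. W_S
    s_loss = F.cross_entropy(student(x), y)
    j_s = torch.autograd.grad(s_loss, x, create_graph=True)[0]

    loss = s_loss + beta / 2 * (j_s - j_t).pow(2).sum()
    optimizer.zero_grad()
    loss.backward()  # second backward pass through the student
    optimizer.step()
    return loss.item()
```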
* They also enforce horizontal flip invariance on gradient attention maps.
* To do that, they propagate horizontally flipped images as well as the originals, backpropagate, and flip the gradient attention maps back. They then add $l_2$ losses on the obtained attentions and outputs, and do a second backpropagation:
$$
\mathcal{L}_{\text{sym}}(W, x) = \mathcal{L}(W, x) + \frac{\beta}{2}\left\|\frac{\partial}{\partial x}\mathcal{L}(W, x)-\text{flip}\left(\frac{\partial}{\partial x}\mathcal{L}(W, \text{flip}(x))\right)\right\|_2
\tag{9}
$$
* Here, $\text{flip}(x)$ denotes the flip operator. Experimentally, this had a regularizing effect on training.
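The flip-symmetry term can be sketched the same way (again illustrative; the squared $l_2$ distance and the helper name `sym_loss` are our assumptions):

```python
import torch
import torch.nn.functional as F

def sym_loss(model, x, y, beta=1e3):
    """Horizontal-flip symmetry regularizer on input-gradient attention maps,
    following the L_sym loss above."""
    x = x.clone().requires_grad_(True)
    x_flip = torch.flip(x.detach(), dims=[-1]).requires_grad_(True)  # flip along width

    loss = F.cross_entropy(model(x), y)
    loss_flip = F.cross_entropy(model(x_flip), y)

    g = torch.autograd.grad(loss, x, create_graph=True)[0]
    g_flip = torch.autograd.grad(loss_flip, x_flip, create_graph=True)[0]
    g_flip = torch.flip(g_flip, dims=[-1])  # flip the gradient map back

    return loss + beta / 2 * (g - g_flip).pow(2).sum()
```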
* Note that in this work, they consider only gradients w.r.t. the input layer, but in general one might have the proposed attention transfer and symmetry constraints w.r.t. higher layers of the network.
## Experiments
* They demonstrate the performance of the proposed techniques on various image classification datasets. See [paper](https://arxiv.org/abs/1612.03928) for details.
* They report that $\mathcal{F}_{\text{sum}}^2$ performs the best among all activation-based mapping functions.
* They also mention that KD struggles to work when the teacher and student have different architectures or depths.
* Code is available on [GitHub](https://github.com/szagoruyko/attention-transfer).
## Conclusion
* They present several ways of transferring attention from one network to another.
* It will be interesting to see how attention transfer works in cases where spatial information matters more, such as object detection or weakly-supervised localization.