# Pretext Invariant Representation Learning

Authors: Ishan Misra and Laurens van der Maaten

**Introduction**

Conventional pretext tasks such as jigsaw puzzles and rotation prediction learn image representations that are covariant to the transformations applied to the image. Pretext-Invariant Representation Learning (PIRL) instead constructs image representations such that an image and its transformed version have similar representations, while both differ from the representations of other images. This makes the representation invariant to the applied transformation.

**Approach Overview**

- Uses jigsaw puzzle transformations to create transformed versions of the input image: the image is divided into nine patches, and the patches are then randomly permuted.
- Trains a CNN $\phi_\theta(\cdot)$ with parameters $\theta$ that constructs image representations $v_I = \phi_\theta(I)$ that are invariant to image transformations $t \in T$, i.e. it produces representations such that $\phi_\theta(I) \approx \phi_\theta(I^t)$.
- The CNN is trained with Noise Contrastive Estimation (NCE), which models the probability of the binary event that $(I, I^t)$ originates from the data distribution:
  - $$h(v_I, v_{I^t}) = \frac{\exp\left(s(v_I, v_{I^t})/\tau\right)}{\exp\left(s(v_I, v_{I^t})/\tau\right) + \sum_{I' \in D_N} \exp\left(s(v_{I^t}, v_{I'})/\tau\right)} \tag{1}$$
  - Here, $D_N$ is a set of $N$ negative samples, $I'$ is an image drawn from $D_N$, $s(\cdot,\cdot)$ is the cosine similarity between feature vectors, and $\tau$ is a temperature.
- In practice, the CNN features $v$ are not used directly. Instead, functions $f(\cdot)$ and $g(\cdot)$ are applied to the CNN features before they are passed to the loss function: $f$ is applied to the activations of the original image, and $g$ to the activations of the transformed image.
- The NCE loss aims to maximize $h(f(v_I), g(v_{I^t}))$ and to minimize $h(g(v_{I^t}), f(v_{I'}))$. Based on (1), the NCE loss can therefore be written as:
  - $$L_{NCE}(I, I^t) = -\log\left[h\left(f(v_I), g(v_{I^t})\right)\right] - \sum_{I' \in D_N} \log\left[1 - h\left(g(v_{I^t}), f(v_{I'})\right)\right]$$
- **Memory bank of negatives**:
  - In a mini-batch SGD setting, fitting a large number of negative samples for the NCE loss into each batch would require an infeasibly large batch size.
  - They therefore build a memory bank $M$ that stores a feature $m_I$ for each image $I$, where $m_I$ is the exponential moving average of the features $f(v_I)$ computed for $I$ in previous epochs. The negatives $f(v_{I'})$ in $L_{NCE}$ are replaced by their memory representations $m_{I'}$.
- **Final loss function**:
  - $L_{NCE}$ never compares $f(v_I)$ with $f(v_{I'})$, i.e. untransformed images are not contrasted with each other. This is addressed by using the final loss:
  - $$L(I, I^t) = \lambda L_{NCE}(m_I, g(v_{I^t})) + (1-\lambda) L_{NCE}(m_I, f(v_I))$$
  - The second term has two effects: (1) it encourages $f(v_I)$ to be close to its memory representation $m_I$, which dampens (regularizes) the weight updates; (2) it encourages $f(v_I)$ to be far from the memory representations $m_{I'}$ of other images $I'$.

**Implementation Details**

- Base network used for experiments: ResNet-50.
- $f(v_I)$ := res5 features of the network => average pooling => linear projection to a 128-d feature vector.
- $g(v_{I^t})$ := the same pipeline as $f(v_I)$, applied to each of the 9 permuted image patches => concatenate all 9 patch vectors => linear projection to a 128-d feature vector.
- **Loss hyper-parameters**: the temperature $\tau$ in $h(\cdot,\cdot)$ is 0.07, the weight used to compute the memory-bank moving averages is 0.5, and $\lambda$ in the final loss is 0.5.
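The loss pieces above (cosine similarity, $h$ from (1), $L_{NCE}$, the memory-bank update and the $\lambda$-weighted final loss) can be sketched in plain Python for intuition. This is a minimal sketch under stated assumptions, not the authors' implementation (which works on batched tensors with gradients): features are lists of floats, the negatives $m_{I'}$ are drawn uniformly from the memory bank, the same negative set $D_N$ is reused in the denominator of every evaluation of $h$ (the notes do not say which negatives enter $h(g(v_{I^t}), m_{I'})$), and the moving average is assumed to take the form $m_I \leftarrow w\,m_I + (1-w)\,f(v_I)$ with $w = 0.5$. The names `cosine`, `h`, `nce_loss`, `ema_update` and `pirl_loss` are hypothetical.

```python
import math
import random

TAU = 0.07         # temperature tau in h (from the notes)
LAMBDA = 0.5       # lambda in the final loss (from the notes)
EMA_WEIGHT = 0.5   # memory-bank moving-average weight (from the notes)


def cosine(a, b):
    """Cosine similarity s(a, b) between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def h(a, b, negatives, tau=TAU):
    """Eq. (1): the pair (a, b) scored against negatives compared with b."""
    pos = math.exp(cosine(a, b) / tau)
    neg = sum(math.exp(cosine(b, n) / tau) for n in negatives)
    return pos / (pos + neg)


def nce_loss(a, b, negatives, tau=TAU):
    """L_NCE(a, b) = -log h(a, b) - sum over negatives m of log(1 - h(b, m))."""
    loss = -math.log(h(a, b, negatives, tau))
    for m in negatives:
        # Assumption: the same negative set is reused in this denominator.
        loss -= math.log(1.0 - h(b, m, negatives, tau))
    return loss


def ema_update(memory, idx, feature, w=EMA_WEIGHT):
    """Assumed update form: m_I <- w * m_I + (1 - w) * f(v_I)."""
    memory[idx] = [w * m + (1.0 - w) * f for m, f in zip(memory[idx], feature)]


def pirl_loss(memory, idx, f_vi, g_vit, num_negatives, lam=LAMBDA, rng=random):
    """lam * L_NCE(m_I, g(v_I^t)) + (1 - lam) * L_NCE(m_I, f(v_I)), with the
    negatives m_I' sampled uniformly (assumption) from the memory bank."""
    candidates = [j for j in range(len(memory)) if j != idx]
    negatives = [memory[j] for j in rng.sample(candidates, num_negatives)]
    m_i = memory[idx]
    return (lam * nce_loss(m_i, g_vit, negatives)
            + (1.0 - lam) * nce_loss(m_i, f_vi, negatives))


if __name__ == "__main__":
    # Toy memory bank: 10 mutually orthogonal one-hot vectors, one per image.
    memory = [[1.0 if k == j else 0.0 for k in range(10)] for j in range(10)]
    rng = random.Random(0)
    aligned = memory[3][:]                # f(v_I) = g(v_I^t) = m_I
    opposite = [-x for x in memory[3]]    # features pointing away from m_I
    print(pirl_loss(memory, 3, aligned, aligned, 4, rng=rng))    # close to 0
    print(pirl_loss(memory, 3, opposite, opposite, 4, rng=rng))  # about 15.67
    ema_update(memory, 3, aligned)        # refresh m_I after the step
```

With one-hot (mutually orthogonal) memory entries, every negative has cosine similarity 0 with the features, so features matching $m_I$ give a loss near zero, while features pointing away from $m_I$ give roughly $1/\tau + \log N \approx 15.67$ for $N = 4$.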
**Experiments**

**Pretext training details**: training data: ImageNet train set; optimizer: SGD; initial learning rate (lr): 0.12; final lr: $1.2 \times 10^{-4}$; lr decay scheme: cosine decay; epochs: 800; batch size: 1024; $N = 32{,}000$ negative samples.

**Object detection results**

- PIRL ($AP^{75} = 59.7$) outperformed all other SSL methods (jigsaw, rotation, NPID++) except MoCo, which was marginally better, and also outperformed supervised ImageNet pre-trained models on $AP^{all}$ and $AP^{75}$.

**Image Classification results (on fixed representations)**

- PIRL achieves the highest top-1 accuracy on ImageNet (63.6%) among the compared SSL methods.

**Qualitative comparison with the covariant counterpart (jigsaw puzzles)**:

**Best performing layer**

- The quality of jigsaw representations improves from conv1 to res4 but drops sharply at res5, because the res5 weights learn task-specific representations (they covary with the applied transformation).
- PIRL's objective, in contrast, focuses on semantic information and does not covary with the applied transformation. Hence, PIRL's best image representations are obtained from its res5 layer.

**Number of image transformations that can be applied**

- In the jigsaw pretext task, training with a large number of image transformations becomes prohibitive, because the number of parameters in the last layer grows linearly with the number of permutations used.
- This problem does not apply to PIRL, since it does not predict the permutation index of a transformation. PIRL can therefore leverage a very large number of possible permutations.
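The scaling argument can be made concrete with a small sketch. `jigsaw_transform` mirrors the transformation described in the Approach Overview (nine patches in a random order); PIRL only needs such a randomly drawn ordering, so any of the $9! = 362{,}880$ permutations can be used without adding parameters. `jigsaw_head_params` counts the weights and biases of a hypothetical linear classifier over permutation indices, as a covariant jigsaw model needs, showing the linear growth. The function names, the nested-list image and the feature width of 2048 (the pooled res5 width of ResNet-50) are illustrative assumptions, not values from the notes.

```python
import math
import random


def jigsaw_transform(image, grid=3, rng=random):
    """Split an image (list of rows) into grid x grid patches and return the
    patches in a random order together with the permutation that was used."""
    ph, pw = len(image) // grid, len(image[0]) // grid
    patches = [
        [row[c * pw:(c + 1) * pw] for row in image[r * ph:(r + 1) * ph]]
        for r in range(grid)
        for c in range(grid)
    ]
    perm = list(range(grid * grid))
    rng.shuffle(perm)  # any of the (grid*grid)! orderings can be drawn
    return [patches[p] for p in perm], perm


def jigsaw_head_params(feature_dim, num_permutations):
    """Weights + biases of a hypothetical linear classifier over permutation
    indices: grows linearly with num_permutations."""
    return (feature_dim + 1) * num_permutations


if __name__ == "__main__":
    image = [[6 * r + c for c in range(6)] for r in range(6)]  # toy 6x6 "image"
    patches, perm = jigsaw_transform(image, rng=random.Random(0))
    print(len(patches), sorted(perm) == list(range(9)))  # 9 True
    all_perms = math.factorial(9)  # 362880 possible orderings of 9 patches
    for n in (1000, 10000, all_perms):
        print(n, jigsaw_head_params(2048, n))  # 2048: hypothetical width
```

A classification head sized for all 362,880 orderings would need hundreds of millions of parameters under these assumptions, whereas the PIRL loss only compares representations and never indexes the permutation.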