# Fine-Tuning can Distort Pretrained Features and Underperform Out-of-Distribution

###### tags: `Research`, `fine-tune`, `PTMs`

> [Paper link](https://arxiv.org/pdf/2202.10054.pdf) | [Code link](https://github.com/AnanyaKumar/transfer_learning) | ICLR 2022 | Seminar 2023/03/02

## Abstract

In this paper, they find that **fine-tuning can achieve worse out-of-distribution (OOD) accuracy than linear probing** when the pretrained features are good and the distribution shift is large. They prove that the OOD error of fine-tuning is high when the model is initialized with a fixed or random head. The reason is that **while fine-tuning learns the head, the lower layers of the neural network change simultaneously and distort the pretrained features.** The paper shows that the simple **two-step strategy of linear probing then full fine-tuning (LP-FT)**, sometimes used as a fine-tuning heuristic, combines the benefits of both fine-tuning and linear probing.

![](https://i.imgur.com/dG7Iq1i.png)

## Introduction

Pretraining a model on a large dataset before transferring it to a downstream task's training data substantially improves accuracy over training from scratch. However, high in-distribution (ID) accuracy is not enough: models must also generalize to circumstances not seen in the training distribution. It is therefore increasingly important to test on **data distributions unseen during training (out-of-distribution; OOD).**

The two most popular transfer methods are **fine-tuning** (running gradient descent on all the model parameters) and **linear probing** (tuning the head while freezing the lower layers). In this work, they investigate the OOD accuracy of both methods. Surprisingly, they find that **fine-tuning can do worse than linear probing in the presence of a large distribution shift.**

:::warning
:warning: Why?
The paper theoretically considers **fine-tuning a two-layer linear network in an overparameterized regression setting**. In this setting, the feature-extractor layer has been pretrained to map high-dimensional inputs to useful, lower-dimensional features. It proves that, with "good" pretrained features, fine-tuning is worse than linear probing on directions outside the span of the training data. Linear probing extrapolates better OOD because it preserves the pretrained features.
:::

**Algorithmic implications**

The paper shows why fine-tuning underperforms. When it fits the ID training data with a randomly initialized head, the feature extractor changes significantly for ID examples. As a result, the features for ID and OOD examples become largely inconsistent. This can be fixed by **initializing with a good head** that does not need to be updated much during fine-tuning, which reduces how much the feature extractor changes. This motivates a simple two-step strategy: first linear-probe to find a good head, then fully fine-tune (LP-FT).

**Empirical validation**

As predicted by the theory, they find that:

1. Fine-tuning never matches the OOD accuracy of linear probing at any point during training (when the pretrained features are good and the OOD shift is large).
2. Fine-tuning changes the features for ID examples more than for OOD examples, leading to distortions.
3. LP-FT changes both ID and OOD features $10\times$ to $100\times$ less than fine-tuning does.
4. Fine-tuning can do better than linear probing OOD if the pretrained features are not very high quality.
5. LP-FT gets the best of both worlds: better accuracy than both fine-tuning and linear probing, both ID and OOD.

## Setup

For some loss function $\ell$, they evaluate classifiers on:

$$
L_{\text{id}}(f)=\underset{(x, y) \sim P_{\text{id}}}{\mathbb{E}}[\ell(f(x), y)] \quad \text{and} \quad L_{\text{ood}}(f)=\underset{(x, y) \sim P_{\text{ood}}}{\mathbb{E}}[\ell(f(x), y)]
$$

**Models.** In this work, they focus on predictors that leverage pretrained representations. They parameterize the final predictor $f$ as follows. Given features $g_B(x) \in \mathbb{R}^k$ for some feature extractor parameters $B \in \mathcal{B}$, and a linear "head" $v \in \mathcal{V}$, the predictor is $f_{v, B}(x)=v^{\top} g_B(x)$. In their experiments, $g_B$ is a deep network; in the theory, $g_B$ is a linear projection.

They focus on two popular methods to learn a predictor $f_{v, B}$ given training data from $P_{\text{id}}$:

1. **Linear probing**, where $B=B_0$ is frozen and the linear head is obtained by minimizing some loss on the training data (e.g., logistic loss for classification, squared loss for regression).
2. **Fine-tuning**, where both $v$ and $B$ are updated by performing gradient descent on some loss on the training data, with $B$ initialized at $B_0$.

## Theory: fine-tuning distorts pretrained features

This section proves the main result: fine-tuning, in which all model parameters are updated, distorts features and gets suboptimal OOD error. They use this result to show that linear probing gets better OOD error but worse ID error than fine-tuning. Finally, they explain why linear probing then fine-tuning can mitigate this ID-OOD tradeoff.

### Linear overparameterized setting

For the analysis, they focus on regression, where $\mathcal{Y}=\mathbb{R}$ and $\ell(\widehat{y}, y)=(\widehat{y}-y)^2$ is the squared loss.

**Models** Recall from the Setup section that they parameterize predictors in terms of **feature extractor** and **head parameters**.
In this section, they study models where the **feature extractor is linear**, i.e., $f_{v, B}(x)=v^{\top} B x$ where $B \in \mathcal{B}=\mathbb{R}^{k \times d}$ and $v \in \mathcal{V}=\mathbb{R}^k$.

**Good pretrained features** For simplicity, they assume the models are well-specified, i.e., $y=v_{\star}^{\top} B_{\star} x$ where $v_{\star} \in \mathbb{R}^k$ and $B_{\star} \in \mathbb{R}^{k \times d}$. Note that $B_{\star}$ and $v_{\star}$ are only unique up to rotations. For any rotation matrix $U$, $\left(U v_{\star}\right)^{\top}\left(U B_{\star}\right) x=v_{\star}^{\top} U^{\top} U B_{\star} x=v_{\star}^{\top} B_{\star} x$. As in prior work, suppose $B_{\star}$ and $B_0$ have been orthogonalized to have orthonormal rows. Suppose also that the pretrained feature extractor $B_0$ is close to $B_{\star}$, so that $d\left(B_0, B_{\star}\right) \leq \epsilon$, where the distance $d$ is defined below.

**Definition 3.1 (Feature Extractor Distance)** The distance between feature extractors $B, B^{\prime} \in \mathbb{R}^{k \times d}$ (with orthonormal rows) is given by the following, where the min is over rotation matrices $U \in \mathbb{R}^{k \times k}$:

$$
d\left(B, B^{\prime}\right)=\min _U\left\|B-U B^{\prime}\right\|_2 .
$$

**Pretraining coverage intuition**

1. There exists a shared set of useful features for ID ($P_{\mathrm{id}}$) and OOD ($P_{\mathrm{ood}}$) data.
2. $B_0$ is close to $B_\star$.

The paper shows that even if the model has these good features, fine-tuning can distort them and lead to low OOD accuracy.
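The rotation invariance and Definition 3.1 can be made concrete with a minimal NumPy sketch. The sizes, seed and helper names (`orthonormal_rows`, `predict`, `feature_distance_upper`) are illustrative assumptions, not from the paper. The exact minimizer over $U$ in Definition 3.1 is not computed here. Instead, the sketch evaluates $\|B - UB'\|_2$ at the orthogonal Procrustes alignment $U = PQ^{\top}$ (from the SVD $BB'^{\top} = P\,\mathrm{diag}(s)\,Q^{\top}$). That value is an upper bound on $d(B, B')$, assuming $U$ may range over all orthogonal matrices.

```python
import numpy as np

def orthonormal_rows(M):
    """Orthonormalize the rows of M (k x d) via QR, keeping the same row space."""
    return np.linalg.qr(M.T)[0].T

def predict(v, B, X):
    """Linear two-layer predictor f_{v,B}(x) = v^T B x, applied to every row x of X."""
    return X @ B.T @ v

def feature_distance_upper(B, B_prime):
    """||B - U B'||_2 at the orthogonal Procrustes alignment U = P Q^T, where
    B B'^T = P diag(s) Q^T. Any fixed U gives an upper bound on d(B, B')."""
    P, _, Qt = np.linalg.svd(B @ B_prime.T)
    return np.linalg.norm(B - P @ Qt @ B_prime, ord=2)

rng = np.random.default_rng(0)
d, k = 20, 3
B_star = orthonormal_rows(rng.standard_normal((k, d)))
v_star = rng.standard_normal(k)
U = np.linalg.qr(rng.standard_normal((k, k)))[0]   # a random orthogonal k x k matrix
X = rng.standard_normal((5, d))

# Rotating both layers leaves the predictor unchanged, so (v_star, B_star) is only unique up to U.
same = np.allclose(predict(U @ v_star, U @ B_star, X), predict(v_star, B_star, X))
print("predictions unchanged under rotation:", same)

# A pretrained extractor close to B_star, re-orthonormalized as the text assumes.
B0 = orthonormal_rows(B_star + 0.05 * rng.standard_normal((k, d)))
print("distance to a rotated copy:", feature_distance_upper(U @ B_star, B_star))
print("upper bound on d(B0, B_star):", feature_distance_upper(B0, B_star))
```

Procrustes is used because it is a closed-form alignment that also absorbs the sign flips QR can introduce. The distance to an exactly rotated copy therefore comes out at floating-point zero.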
**Training methods** For the training loss $\widehat{L}(v, B)=\left\|X B^{\top} v-Y\right\|_2^2$, the gradient flow differential equations for LP and FT are:

$$
\begin{gathered}
\partial_t v_{\mathrm{ft}}(t)=-\nabla_v \widehat{L}\left(v_{\mathrm{ft}}(t), B_{\mathrm{ft}}(t)\right), \quad \partial_t B_{\mathrm{ft}}(t)=-\nabla_B \widehat{L}\left(v_{\mathrm{ft}}(t), B_{\mathrm{ft}}(t)\right), \\
\partial_t v_{\mathrm{lp}}(t)=-\nabla_v \widehat{L}\left(v_{\mathrm{lp}}(t), B_0\right), \quad \partial_t B_{\mathrm{lp}}(t)=0,
\end{gathered}
$$

initialized with $B_{\mathrm{ft}}(0)=B_{\mathrm{lp}}(0)=B_0$ and $v_{\mathrm{ft}}(0)=v_{\mathrm{lp}}(0)=v_0$. In practice, the head parameter $v_0$ is initialized randomly. Their results hold for any standard random initialization, for example $v_0 \sim \mathcal{N}\left(0, \sigma^2 I\right)$ for any $\sigma^2$, or zero initialization $v_0=0$. Recall that the initial value of the feature extractor, $B_0$, is obtained via pretraining.

The final LP and FT solutions are the limit points of the corresponding gradient flows:

$$
\begin{aligned}
& v_{\mathrm{ft}}^{\infty}=\lim _{t \rightarrow \infty} v_{\mathrm{ft}}(t) \text { and } B_{\mathrm{ft}}^{\infty}=\lim _{t \rightarrow \infty} B_{\mathrm{ft}}(t), \\
& v_{\mathrm{lp}}^{\infty}=\lim _{t \rightarrow \infty} v_{\mathrm{lp}}(t) \text { and } B_{\mathrm{lp}}^{\infty}=\lim _{t \rightarrow \infty} B_{\mathrm{lp}}(t)=B_0 .
\end{aligned}
$$

![](https://i.imgur.com/EPfX4Et.png)

### Fine-tuning distorts pretrained features

They first present the key intuitions behind the potential issues of FT. They then present their formal theorem lower-bounding the OOD error of FT.

**Key intuitions**

1. **Features get distorted**: representations change only in the ID subspace (i.e., the subspace spanned by the training data) and are unchanged in the orthogonal subspace.
   To see this, take the derivative of the training loss $\widehat{L}(v, B)=\left\|X B^{\top} v-Y\right\|_2^2$ with respect to the feature extractor parameter $B$:

   $$
   \nabla_B \widehat{L}(v, B)=2 v\left(X B^{\top} v-Y\right)^{\top} X
   $$

   Let $u$ be a direction orthogonal to the training subspace $S=\operatorname{rowspace}(X)$. Then $X u=0$, so $\nabla_B \widehat{L}(v, B)\, u=0$: the gradient updates to $B$ do not modify $B u$ for $u \in S^{\perp}$. However, the gradient is non-zero for directions $u$ in the ID subspace, and the corresponding features $B u$ change during fine-tuning. They call this **feature distortion**: the features change in some directions but not in others.

   ![](https://i.imgur.com/8cjIa0n.png)

2. **Distorted features can lead to higher OOD error**: ==The only way the pretrained features are not distorted, and only scaled, during FT is if the initial feature extractor $B_0$ is exactly aligned with the ID subspace.== In high dimensions, they measure the alignment between $B_0$ and the ID subspace with the largest principal angle.

**Definition 3.2 (largest principal angle).** Let $A$ and $B$ be arbitrary subspaces, and let $E$ and $F$ be matrices with orthonormal columns that span $A$ and $B$ respectively, with $r=\min (\operatorname{dim}(A), \operatorname{dim}(B))$. Then $\cos \theta_{\max }(A, B)=\sigma_r (E^{\top} F)$, the $r$-th largest singular value of $E^{\top} F$. Note that $E$ and $F$ are not unique, but $\sigma_r(E^{\top} F)$ is the same for every valid choice of $E$ and $F$.

**General result on the OOD error of fine-tuning** Their main theorem lower-bounds the OOD error of fine-tuning outside the span of the training data. In the "Linear probing vs. fine-tuning" subsection, they compare this lower bound with an upper bound on the OOD error of linear probing.
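The feature-distortion argument can be checked numerically with a minimal NumPy sketch on a small random instance of the overparameterized linear setting. The sizes, seed, step size and the 0.1 perturbation of $B_\star$ are illustrative assumptions. `grads` and `gradient_descent` are hypothetical helper names. The sketch takes forward-Euler steps along the LP and FT gradient flows; freezing $B$ gives LP. It then compares how much $B$ moved on $S=\operatorname{rowspace}(X)$ versus on $S^{\perp}$.

```python
import numpy as np

def grads(v, B, X, Y):
    """Gradients of L(v, B) = ||X B^T v - Y||_2^2 w.r.t. the head v and the extractor B."""
    r = X @ B.T @ v - Y                 # residual, shape (n,)
    grad_v = 2 * B @ (X.T @ r)          # shape (k,)
    grad_B = 2 * np.outer(v, r @ X)     # 2 v r^T X, shape (k, d)
    return grad_v, grad_B

def gradient_descent(v, B, X, Y, lr=0.01, steps=2000, train_B=True):
    """Forward-Euler discretization of the gradient flow; train_B=False gives linear probing."""
    v, B = v.copy(), B.copy()
    for _ in range(steps):
        gv, gB = grads(v, B, X, Y)
        v = v - lr * gv
        if train_B:
            B = B - lr * gB
    return v, B

rng = np.random.default_rng(0)
n, d, k = 10, 30, 3                                   # n < d: rowspace(X) is a proper subspace of R^d
X = rng.standard_normal((n, d)) / np.sqrt(d)
B_star = np.linalg.qr(rng.standard_normal((d, k)))[0].T
v_star = rng.standard_normal(k)
Y = X @ B_star.T @ v_star                             # well-specified labels
B0 = np.linalg.qr((B_star + 0.1 * rng.standard_normal((k, d))).T)[0].T  # imperfect pretrained features
v0 = rng.standard_normal(k)                           # random head

# Orthonormal bases for S = rowspace(X) and its orthogonal complement S_perp.
rank = np.linalg.matrix_rank(X)
Vt = np.linalg.svd(X)[2]
S_basis, S_perp = Vt[:rank].T, Vt[rank:].T

_, B_lp = gradient_descent(v0, B0, X, Y, train_B=False)
_, B_ft = gradient_descent(v0, B0, X, Y)

print("LP changed the features:", not np.array_equal(B_lp, B0))
print("FT change on S:     ", np.linalg.norm((B_ft - B0) @ S_basis))
print("FT change on S_perp:", np.linalg.norm((B_ft - B0) @ S_perp))
```

Under these assumptions, LP leaves $B$ untouched. FT moves $B$ only on $S$, while the change on $S^{\perp}$ stays at floating-point noise. This is the pattern the notes call feature distortion.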
**Theorem 3.3.** In the overparameterized linear setting, let $S^{\perp}=\operatorname{rowspace}(X)^{\perp}$ and $R_0=\operatorname{rowspace}\left(B_0\right)$. Let $v_{\star}, B_{\star}$ be the optimal parameters, with $w_{\star}=B_{\star}^{\top} v_{\star}$. If $\cos \theta_{\max }\left(R_0, S^{\perp}\right)>0$, then for all time steps $t$, **the OOD error of the fine-tuning** iterates $\left(B_{\mathrm{ft}}(t), v_{\mathrm{ft}}(t)\right)$ is lower bounded:

$$
\sqrt{L_{\mathrm{ood}}\left(v_{\mathrm{ft}}(t), B_{\mathrm{ft}}(t)\right)} \geq \sqrt{\sigma_{\min }(\Sigma)}\left(\frac{\cos \theta_{\max }\left(R_0, S^{\perp}\right)}{\sqrt{k}} \frac{\min \left(\varphi, \varphi^2 /\left\|w_{\star}\right\|_2\right)}{\left(1+\left\|w_{\star}\right\|_2\right)^2}-\epsilon\right),
$$

where:

- $\Sigma=\mathbb{E}_{x \sim P_{\mathrm{ood}}}\left[x x^{\top}\right]$ is the OOD input covariance;
- $\varphi^2=\left|\left(v_0^{\top} v_{\star}\right)^2-\left(v_{\star}^{\top} v_{\star}\right)^2\right|$ is defined to be the initial head alignment error;
- $\epsilon \geq d\left(B_0, B_{\star}\right)$ is the error in the pretrained feature extractor.

### Linear probing vs. fine-tuning

**Assumption 3.4 (ID subspace assumption).** They assume that the ID data lies on an $m$-dimensional subspace $S$ where $k<m<d-k$, and that there are $n \geq m$ training examples. Formally, let $P_z$ be a distribution on $\mathbb{R}^m$ which has a density, and let the columns of $F \in \mathbb{R}^{d \times m}$ form an orthonormal basis for $S$. Then $P_{\text{id}}$ has the distribution of $F z$ where $z \sim P_z$.
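The quantities in Definition 3.2, Theorem 3.3 and Assumption 3.4 can be computed directly, as in the minimal NumPy sketch below. The sketch assumes:

- a standard Gaussian $P_z$;
- isotropic OOD inputs, so $\sigma_{\min}(\Sigma)=1$;
- perfect features $B_0=B_\star$, so $\epsilon=0$ is admissible.

The sizes, seed and helper names (`cos_theta_max`, `ft_ood_lower_bound`) are illustrative. The function only evaluates the right-hand side of the bound; it does not train anything.

```python
import numpy as np

def cos_theta_max(E, F):
    """cos of the largest principal angle between span(E) and span(F) (Definition 3.2).
    E and F must have orthonormal columns."""
    r = min(E.shape[1], F.shape[1])
    s = np.linalg.svd(E.T @ F, compute_uv=False)   # singular values in descending order
    return s[r - 1]                                  # the r-th largest

def ft_ood_lower_bound(cos_max, k, v0, v_star, w_star, sigma_min=1.0, eps=0.0):
    """Right-hand side of Theorem 3.3, a lower bound on sqrt(L_ood) for every FT iterate."""
    phi = np.sqrt(abs((v0 @ v_star) ** 2 - (v_star @ v_star) ** 2))   # initial head alignment error
    w = np.linalg.norm(w_star)
    return np.sqrt(sigma_min) * (cos_max / np.sqrt(k) * min(phi, phi ** 2 / w) / (1 + w) ** 2 - eps)

rng = np.random.default_rng(0)
d, m, k, n = 40, 10, 3, 50                            # k < m < d - k and n >= m (Assumption 3.4)
F = np.linalg.qr(rng.standard_normal((d, m)))[0]      # orthonormal basis of the ID subspace S
X = rng.standard_normal((n, m)) @ F.T                 # rows x = F z with z ~ N(0, I_m) (assumed P_z)
S_perp = np.linalg.svd(F)[0][:, m:]                   # orthonormal basis of S_perp

B0 = np.linalg.qr(rng.standard_normal((d, k)))[0].T   # perfect features: B0 = B_star, orthonormal rows
v_star = rng.standard_normal(k)
w_star = B0.T @ v_star
v0 = rng.standard_normal(k)                           # random initial head

c = cos_theta_max(B0.T, S_perp)                       # R0 = rowspace(B0) has orthonormal basis B0.T
print("rank(X) =", np.linalg.matrix_rank(X), "| cos theta_max(R0, S_perp) =", round(c, 3))
print("Theorem 3.3 lower bound on sqrt(L_ood):", ft_ood_lower_bound(c, k, v0, v_star, w_star))
```

Setting `v0 = v_star` makes $\varphi=0$, and the bound collapses to zero. This matches the intuition that a well-initialized head avoids the distortion penalty.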
Recall that the ID error is the expected squared error over the ID distribution $P_{\mathrm{id}}$:

$$
L_{\mathrm{id}}(v, B)=\underset{x \sim P_{\mathrm{id}}}{\mathbb{E}}\left[\left(v_{\star}^{\top} B_{\star} x-v^{\top} B x\right)^2\right]
$$

**Theorem 3.5 (Informal version of Theorem A.8).** In the linear overparameterized setting, under the ID subspace assumption (Assumption 3.4), suppose $\cos \theta_{\max }\left(R_\star, S\right) \neq 0$ and $\cos \theta_{\max }\left(R_\star, S^{\perp}\right) \neq 0$, where $R_\star=\operatorname{rowspace}\left(B_{\star}\right)$. Then

$$
\frac{L_{\mathrm{ood}}\left(v_{\mathrm{lp}}^{\infty}, B_0\right)}{L_{\mathrm{ood}}\left(v_{\mathrm{ft}}(t), B_{\mathrm{ft}}(t)\right)} \stackrel{p}{\rightarrow} 0, \text { as } B_0 \rightarrow B_{\star} .
$$

This holds for all times $t$ for FT, and therefore also for the limit $v_{\mathrm{ft}}^{\infty}, B_{\mathrm{ft}}^{\infty}$. The LP iterates converge to $v_{\mathrm{lp}}^{\infty}, B_0$ because LP is gradient flow on a convex problem.

**Proposition 3.6.** In the linear overparameterized setting, under the ID subspace assumption (Assumption 3.4), let $R_0=\operatorname{rowspace}\left(B_0\right)$ and $R_{\mathrm{aug}}=\operatorname{Span}\left(\left\{w_{\star}\right\} \cup R_0\right)$. Suppose $w_{\star} \notin R_0$, $\cos \theta_{\max }\left(S, R_{\mathrm{aug}}\right) \neq 0$, and fine-tuning converges to a local minimum of its loss. Then fine-tuning does better ID almost surely: $L_{\mathrm{id}}\left(v_{\mathrm{ft}}^{\infty}, B_{\mathrm{ft}}^{\infty}\right)<L_{\mathrm{id}}\left(v_{\mathrm{lp}}^{\infty}, B_0\right)$ with probability 1 over the randomness of the training examples.

To summarize, this section proved that there is a tradeoff between ID and OOD error: FT has lower ID error but higher OOD error than LP.
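A small run makes the OOD side of this tradeoff, and the LP-FT remedy, concrete. The minimal NumPy sketch below assumes:

- the perfect-features case of Proposition 3.7, $B_0=B_\star$;
- Gaussian ID inputs on $S$ and isotropic Gaussian OOD inputs, so that $L_{\mathrm{id}}=\|F^{\top}(B^{\top}v-w_\star)\|_2^2$ and $L_{\mathrm{ood}}=\|B^{\top}v-w_\star\|_2^2$ exactly;
- a zero-initialized head for FT;
- LP solved in closed form by least squares, which is the limit of the LP gradient flow when $XB_0^{\top}$ has full column rank.

The sizes, seed, step size, step count and $v_\star=\mathbf{1}$ are illustrative choices. Rescaling the training rows by $1/\sqrt{n}$ only rescales the loss.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, k, n = 40, 10, 3, 50                                # k < m < d - k, n >= m
F = np.linalg.qr(rng.standard_normal((d, m)))[0]          # orthonormal basis of the ID subspace S
X = rng.standard_normal((n, m)) @ F.T / np.sqrt(n)       # ID inputs x = F z; 1/sqrt(n) only rescales the loss
B_star = np.linalg.qr(rng.standard_normal((d, k)))[0].T  # ground-truth extractor, orthonormal rows
v_star = np.ones(k)                                      # ground-truth head (illustrative choice)
w_star = B_star.T @ v_star
Y = X @ w_star                                           # well-specified labels
B0 = B_star.copy()                                       # perfect pretrained features (assumption)

def grads(v, B):
    """Gradients of ||X B^T v - Y||^2 with respect to v and B."""
    r = X @ B.T @ v - Y
    return 2 * B @ (X.T @ r), 2 * np.outer(v, r @ X)

def fine_tune(v, B, lr=0.01, steps=5000):
    """Full fine-tuning: forward-Euler steps on the joint gradient flow of (v, B)."""
    for _ in range(steps):
        gv, gB = grads(v, B)
        v, B = v - lr * gv, B - lr * gB
    return v, B

def errors(v, B):
    """Exact (L_id, L_ood) for z ~ N(0, I_m) ID inputs and N(0, I_d) OOD inputs."""
    dw = B.T @ v - w_star
    return float(np.sum((F.T @ dw) ** 2)), float(np.sum(dw ** 2))

v_lp = np.linalg.lstsq(X @ B0.T, Y, rcond=None)[0]       # LP: least-squares head on frozen B0
v_ft, B_ft = fine_tune(np.zeros(k), B0)                 # FT from a zero-initialized head
v_lpft, B_lpft = fine_tune(v_lp, B0)                    # LP-FT: FT starting from the LP head

results = {"LP": (v_lp, B0), "FT": (v_ft, B_ft), "LP-FT": (v_lpft, B_lpft)}
for name, (v, B) in results.items():
    id_err, ood_err = errors(v, B)
    change = np.linalg.norm(B - B0)
    print(f"{name:6s} ID={id_err:.2e}  OOD={ood_err:.2e}  ||B - B0||={change:.2e}")
```

In this setup, LP and LP-FT should keep the OOD error at numerical zero with essentially no feature change. FT from the zero head distorts $B$ and keeps a positive OOD error. With perfect features, LP already has zero ID error. Reproducing the ID side of the tradeoff (Proposition 3.6) would therefore need an imperfect $B_0$, for example a small perturbation of $B_\star$.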
### Linear probing then fine-tuning: a simple variant to mitigate tradeoffs

Proposition 3.7 is a first-cut result. It illustrates that, if the head is initialized well, full fine-tuning does not distort features.

**Proposition 3.7.** Suppose the pretrained features are perfect, $B_0=U B_{\star}$ for some rotation $U$. Let $R_0=\operatorname{rowspace}\left(B_0\right)$. Under the non-degeneracy conditions $\cos \theta_{\max }\left(R_0, S\right) \neq 0$ and $\cos \theta_{\max }\left(R_0, S^{\perp}\right) \neq 0$:

$$
\begin{aligned}
& \forall t, \; L_{\mathrm{ood}}\left(B_{\mathrm{ft}}(t)^{\top} v_{\mathrm{ft}}(t)\right)>0, \text { if } v_0 \sim \mathcal{N}\left(0, \sigma^2 I\right) \text { is randomly initialized (FT)}, \\
& \forall t, \; L_{\mathrm{ood}}\left(B_{\mathrm{ft}}(t)^{\top} v_{\mathrm{ft}}(t)\right)=0, \text { if } v_0 \text { is initialized to } v_{\mathrm{lp}}^{\infty} \text{ (LP-FT)}.
\end{aligned}
$$

The case where the features are not perfect ($d\left(B_0, B_{\star}\right)>0$) is challenging to analyze. Except in very special cases, there is no closed form for the fine-tuning iterates $\left(v_{\mathrm{ft}}(t), B_{\mathrm{ft}}(t)\right)$. The proof of Theorem 3.3 leveraged invariants to show a lower bound on the error of fine-tuning when $v_0$ and $v_{\star}$ are different, but the authors were not able to show an upper bound.

## Experiments

The datasets:

1. DomainNet: domain adaptation dataset
2. Living-17 and Entity-30: sub-population shift datasets from the BREEDS benchmark
3. FMoW Geo-shift
4. CIFAR-10 $\rightarrow$ STL
5. CIFAR-10 $\rightarrow$ CIFAR-10.1
6. ImageNet-1K

**Pretraining and models.** They use a CLIP-pretrained ViT-B/16 for ImageNet and a ResNet-50 architecture for the other datasets.

### Linear probing vs. fine-tuning

![](https://i.imgur.com/3yqvXnY.png)

![](https://i.imgur.com/DzaQGS8.png)

### Linear probing then fine-tuning (LP-FT)

![](https://i.imgur.com/dG7Iq1i.png)

### Examining the feature distortion theory

![](https://i.imgur.com/XRnwsK2.png)

![](https://i.imgur.com/vFnOcUm.png)

![](https://i.imgur.com/VoghZJH.png)

## Related work and discussion

- Fine-tuning vs. linear probing
    - FT is the method of choice for improving accuracy
    - LP is used to analyze properties of representations
- The benefit of preserving pretrained features: this work shows something stronger: at no point in the fine-tuning process does FT outperform LP OOD
- Mitigating ID-OOD tradeoffs
- Theoretical analysis of transfer learning: prior works focus on ID error, while this paper analyzes OOD error

## Conclusion

This paper shows theoretically and empirically that **preserving features might be important for robustness**. It also shows that simpler approaches like **linear probing can improve out-of-distribution (OOD) performance.** The paper introduces tools and ideas for a central challenge: characterizing the trajectory from a specific initialization in the presence of multiple global optima (the implicit regularization effect of initialization). Finally, they show that **LP-FT can mitigate the tradeoff between ID and OOD accuracy in their setting.**