###### tags: `PaperReview`
# Bayesian Low-Rank Adaptation for Large Language Models
> University of Bristol
> University of Massachusetts, Amherst
## Introduction
- In recent years, fine-tuning large language models (LLMs) has become increasingly important.
- However, fine-tuned LLMs often exhibit **overconfidence**, which is problematic in **safety-critical applications** or when making decisions in areas where **limited data is available**, such as **medical diagnosis, finance, and experimental design**.
- **Bayesian deep learning** is commonly proposed as a solution to overconfidence in deep networks
- The advantages of **Bayesian fine-tuning are much clearer** than the advantages of Bayesian pretraining.
- **Fine-tuned models are typically poorly calibrated** even after very large-scale instruction fine-tuning
- **Poor calibration may arise in fine-tuning settings** because there is often far less data available for fine-tuning than for pre-training.
- The authors develop the **first Bayesian inference method designed specifically for the LoRA parameters** in LLM fine-tuning.
<!--
## Related work
- Past work on integrating Bayesian inference with language models has **usually operated in the large-scale pre-training setting**, where the advantages of Bayes are **unclear**, because **pre-training datasets are very large** and large-scale pretrained models seem to be **reasonably well-calibrated even without Bayes**
- -->
## Background
### LoRA
- LLMs contain many large weight matrices, denoted $\mathbf{W}_0 \in \mathbb{R}^{n_{\text {out }} \times n_{\text {in }}}$, with inputs $\mathbf{a}$ and outputs $\mathbf{h}$. LoRA keeps $\mathbf{W}_0$ fixed and introduces a learned perturbation $\Delta \mathbf{W}$ to the weight matrix,
$$
\mathbf{h}=\mathbf{W}_0 \mathbf{a}+\Delta \mathbf{W a}=\mathbf{W}_0 \mathbf{a}+\mathbf{B A} \mathbf{a} .
$$
- $\Delta \mathbf{W}$ is low-rank because it is written as the product of two matrices, $\mathbf{B} \in \mathbb{R}^{n_{\text {out }} \times n_{\text {lr }}}$ and $\mathbf{A} \in \mathbb{R}^{n_{\mathrm{lr}} \times n_{\text {in }}}$, where $n_{\text {lr }}$ is much smaller than $n_{\text {in }}$ and $n_{\text {out }}$ (for instance, $n_{\mathrm{lr}}=8$ while $n_{\text {in }}$ and $n_{\text {out }}$ may be 4096). The total number of LoRA parameters for this weight matrix is therefore $n_{\mathrm{lr}}\left(n_{\mathrm{in}}+n_{\text {out }}\right)$, compared with $n_{\mathrm{in}} n_{\text {out }}$ for the full matrix.
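A minimal NumPy sketch of the LoRA forward pass and the parameter count above. The matrix sizes and random values are illustrative assumptions, not the paper's configuration; the point is that computing $\mathbf{B}(\mathbf{A}\mathbf{a})$ never forms the full $n_{\text{out}} \times n_{\text{in}}$ matrix $\mathbf{B}\mathbf{A}$.

```python
import numpy as np

def lora_param_count(n_in: int, n_out: int, n_lr: int) -> int:
    """Trainable LoRA parameters for one weight matrix: n_lr * (n_in + n_out)."""
    return n_lr * (n_in + n_out)

def lora_forward(W0: np.ndarray, A: np.ndarray, B: np.ndarray, a: np.ndarray) -> np.ndarray:
    """h = W0 a + B (A a); the n_out x n_in product B A is never formed."""
    return W0 @ a + B @ (A @ a)

rng = np.random.default_rng(0)
n_out, n_in, n_lr = 64, 32, 4            # small illustrative sizes
W0 = rng.standard_normal((n_out, n_in))  # frozen pretrained weight
A = rng.standard_normal((n_lr, n_in))    # trainable LoRA factor
B = rng.standard_normal((n_out, n_lr))   # trainable LoRA factor (random here, for illustration)
a = rng.standard_normal(n_in)

h = lora_forward(W0, A, B, a)
assert np.allclose(h, (W0 + B @ A) @ a)  # same as applying the explicit Delta W = B A
print(np.linalg.matrix_rank(B @ A))      # 4, i.e. rank at most n_lr
print(lora_param_count(4096, 4096, 8))   # 65536
print(4096 * 4096)                       # 16777216
```

With $n_{\text{in}} = n_{\text{out}} = 4096$ and $n_{\text{lr}} = 8$, the adapter has 65,536 parameters versus about 16.8M for the full matrix, roughly 0.4%.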
### Laplace Approximations
- In Bayesian inference for classification or next-token prediction, the goal is to **find the full posterior**,
$$
P(\theta | X, y) \propto P(y|X, \theta) \cdot P(\theta)
$$
where $P(\theta | X, y)$ is the posterior, $P(y|X, \theta)$ is the likelihood (e.g. a softmax categorical distribution for classification), and $P(\theta)$ is the prior, here a Gaussian $\mathcal{N}(0, \lambda^{-1}\mathbf{I})$ with precision $\lambda$.
- Computing the exact posterior is intractable, so the first step is to **find the maximum a posteriori (MAP) solution**:
$$
\begin{aligned}
\mathcal{L}(\mathbf{y}, \mathbf{X} ; \boldsymbol{\theta}) & =\log \mathrm{P}(\mathbf{y} \mid \mathbf{X}, \boldsymbol{\theta})+\log \mathrm{P}(\boldsymbol{\theta})=\log \mathrm{P}(\boldsymbol{\theta} \mid \mathbf{X}, \mathbf{y})+\text { const } \\
\boldsymbol{\theta}_{\text {MAP }} & =\underset{\boldsymbol{\theta}}{\operatorname{argmax}} \mathcal{L}(\mathbf{y}, \mathbf{X} ; \boldsymbol{\theta}) .
\end{aligned}
$$
- Then the Laplace approach performs a **second-order Taylor expansion** of the log-joint around $\theta_{\text {MAP }}$. The first-order term vanishes because the gradient is zero at the maximum, leaving
$$
\mathcal{L}(\mathbf{y}, \mathbf{X} ; \boldsymbol{\theta}) \approx \mathcal{L}\left(\mathbf{y}, \mathbf{X} ; \boldsymbol{\theta}_{\mathrm{MAP}}\right)+\frac{1}{2}\left(\boldsymbol{\theta}-\boldsymbol{\theta}_{\mathrm{MAP}}\right)^T\left(\left.\nabla_{\boldsymbol{\theta}}^2 \mathcal{L}(\mathbf{y}, \mathbf{X} ; \boldsymbol{\theta})\right|_{\boldsymbol{\theta}_{\mathrm{MAP}}}\right)\left(\boldsymbol{\theta}-\boldsymbol{\theta}_{\mathrm{MAP}}\right) .
$$
- Since this is now a quadratic function of $\theta$, **approximate the posterior as a Gaussian centered at $\theta_{\text {MAP }}$** with covariance given by the **inverse of the negative Hessian** of the log-joint,
$$
\begin{aligned}
\mathrm{P}(\boldsymbol{\theta} \mid \mathbf{X}, \mathbf{y}) & \approx \mathcal{N}\left(\boldsymbol{\theta} ; \boldsymbol{\theta}_{\mathrm{MAP}}, \boldsymbol{\Sigma}\right), \\
\boldsymbol{\Sigma} & =\left(-\left.\nabla_{\boldsymbol{\theta}}^2 \mathcal{L}(\mathbf{y}, \mathbf{X} ; \boldsymbol{\theta})\right|_{\boldsymbol{\theta}_{\mathrm{MAP}}}\right)^{-1}=\left(-\left.\nabla_{\boldsymbol{\theta}}^2 \log \mathrm{P}(\mathbf{y} \mid \mathbf{X}, \boldsymbol{\theta})\right|_{\boldsymbol{\theta}_{\mathrm{MAP}}}+\lambda \mathbf{I}\right)^{-1} ,
\end{aligned}
$$
where $\lambda \mathbf{I}$ is the negative Hessian of the Gaussian log-prior.
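To make the MAP-then-Laplace recipe concrete, here is a small NumPy sketch on a hypothetical two-parameter logistic-regression problem (not an LLM, and not the paper's code). Newton's method finds $\theta_{\text{MAP}}$ of the log-joint with prior $\mathcal{N}(0, \lambda^{-1}\mathbf{I})$, and $\boldsymbol{\Sigma}$ is the inverse of the negative Hessian of the log-joint at that point. For LoRA the same steps apply in principle, but the Hessian is far too large to form exactly.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def neg_hessian(X, theta, lam):
    """Negative Hessian of the log-joint: X^T diag(p (1 - p)) X + lam I."""
    p = sigmoid(X @ theta)
    return X.T @ (X * (p * (1 - p))[:, None]) + lam * np.eye(X.shape[1])

def laplace_logistic(X, y, lam, n_iter=50):
    """MAP estimate and Laplace covariance for logistic regression, prior N(0, lam^{-1} I)."""
    theta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        grad = X.T @ (y - sigmoid(X @ theta)) - lam * theta          # gradient of log-joint L
        theta = theta + np.linalg.solve(neg_hessian(X, theta, lam), grad)  # Newton ascent step
    Sigma = np.linalg.inv(neg_hessian(X, theta, lam))                # inverse negative Hessian
    return theta, Sigma

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 2))                                   # hypothetical inputs
y = (rng.random(200) < sigmoid(X @ np.array([2.0, -1.0]))).astype(float)  # hypothetical labels
theta_map, Sigma = laplace_logistic(X, y, lam=1.0)
print(theta_map, np.sqrt(np.diag(Sigma)))  # MAP estimate and posterior standard deviations
```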
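- The negative Hessian of the log-likelihood is not guaranteed to be positive definite, so to **ensure positive definiteness of the covariance** it is replaced by the **Fisher information**, which is always positive semi-definite (adding $\lambda \mathbf{I}$ with $\lambda > 0$ then makes the precision positive definite):
$$
\mathbf{F}(\boldsymbol{\theta})=\sum_{n=1}^N \mathbb{E}_{\mathrm{P}\left(y \mid f_{\boldsymbol{\theta}}\left(\mathbf{x}_n\right)\right)}\left[\nabla_{\boldsymbol{\theta}} \log \mathrm{P}\left(y \mid f_{\boldsymbol{\theta}}\left(\mathbf{x}_n\right)\right)\left(\nabla_{\boldsymbol{\theta}} \log \mathrm{P}\left(y \mid f_{\boldsymbol{\theta}}\left(\mathbf{x}_n\right)\right)\right)^T\right],
$$
where the expectation is over labels $y$ drawn from the model's own output distribution, not the observed labels.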
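As an illustration of why the Fisher is a safe replacement, the following NumPy sketch computes it exactly for a hypothetical linear softmax layer $\mathbf{z}_n = \mathbf{W}\mathbf{x}_n$ (a stand-in for a last layer, not the paper's implementation). The expectation over $y$ uses the model's own predictive probabilities, and for softmax outputs the Fisher with respect to the logits reduces to $\operatorname{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^T$, which is positive semi-definite.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def fisher_wrt_logits(z):
    """Exact E_{y ~ softmax(z)}[g g^T] with g = grad_z log softmax(z)[y] = onehot(y) - p."""
    p = softmax(z)
    C = z.size
    F = np.zeros((C, C))
    for y in range(C):                 # expectation over the model's own labels
        g = np.eye(C)[y] - p
        F += p[y] * np.outer(g, g)
    return F

def fisher_linear_softmax(X, W):
    """Fisher w.r.t. vec(W) (column-major) for logits z_n = W x_n, summed over n."""
    C, d = W.shape
    F = np.zeros((C * d, C * d))
    for x in X:
        J = np.kron(x[None, :], np.eye(C))   # dz/dvec(W), shape (C, C*d)
        F += J.T @ fisher_wrt_logits(W @ x) @ J
    return F

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 4))   # hypothetical inputs
W = rng.standard_normal((3, 4))    # hypothetical 3-class linear layer
F = fisher_linear_softmax(X, W)
print(np.linalg.eigvalsh(F).min() >= -1e-10)   # True: positive semi-definite
```

Each per-example term works out to $(\mathbf{x}\mathbf{x}^T) \otimes (\operatorname{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^T)$, the same Kronecker structure that KFAC exploits.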
- However, the full Hessian or Fisher is a $P \times P$ matrix, where $P$ is the number of parameters; this is **far too large** to store or invert, even when only the LoRA parameters are considered (**6M parameters** for $n_{\mathrm{lr}}=8$ in Llama2-7B).
- Further structure must therefore be imposed on the Hessian: either use **just the Hessian for the last layer**, or use a **Kronecker-factored (KFAC) structure for individual weight matrices**.
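A quick back-of-the-envelope check of "too large", using the 6M figure quoted above. The KFAC comparison assumes a single hypothetical LoRA factor $\mathbf{A}$ with $n_{\text{in}} = 4096$ and $n_{\text{lr}} = 8$, whose input and output-gradient factors are $n_{\text{in}} \times n_{\text{in}}$ and $n_{\text{lr}} \times n_{\text{lr}}$.

```python
P = 6_000_000                 # LoRA parameter count quoted above for Llama2-7B (n_lr = 8)
entries = P ** 2              # dense P x P Hessian or Fisher
bytes_fp32 = entries * 4      # 4 bytes per float32 entry
print(f"{entries:.2e} entries, {bytes_fp32 / 1e12:.0f} TB in float32")
# → 3.60e+13 entries, 144 TB in float32

# Per-matrix KFAC for one hypothetical LoRA factor A (n_lr x n_in):
# the two Kronecker factors are n_in x n_in and n_lr x n_lr.
n_in, n_lr = 4096, 8
kfac_entries = n_in ** 2 + n_lr ** 2
print(kfac_entries)           # 16777280
```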
- In KFAC, the Fisher is approximated by a block-diagonal matrix with one block per linear layer. For the $\ell^{th}$ layer, denote its input as $\mathbf{a}_{\ell-1}$ and its output as $\mathbf{b}_\ell$. The Fisher block is then:
$$
\mathbf{F}_{\ell}=\sum_{n=1}^N \mathbb{E}_{\mathrm{P}\left(y \mid f_{\boldsymbol{\theta}}\left(\mathbf{x}_n\right)\right)}\left[\left(\mathbf{a}_{\ell-1} \mathbf{a}_{\ell-1}^T\right) \otimes\left(\mathbf{g}_{\ell} \mathbf{g}_{\ell}^T\right)\right]
$$
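where $\mathbf{g}_{\ell}=\nabla_{\mathbf{b}_{\ell}} \log \mathrm{P}(\mathbf{y} \mid \mathbf{X}, \boldsymbol{\theta})$ is the gradient of the log-likelihood with respect to the layer outputs. KFAC then approximates the expectation of this Kronecker product by the Kronecker product of the expectations, so only two small factors, built from $\mathbf{a}_{\ell-1} \mathbf{a}_{\ell-1}^T$ and $\mathbf{g}_{\ell} \mathbf{g}_{\ell}^T$, need to be stored per weight matrix.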
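The Kronecker identity behind the block formula above, and the benefit of factorizing, can be checked numerically. This NumPy sketch uses random vectors as hypothetical stand-ins for the layer inputs $\mathbf{a}_{\ell-1}$ and output gradients $\mathbf{g}_\ell$ (in practice $\mathbf{g}_\ell$ comes from backpropagating labels sampled from the model), and averages rather than sums over examples. It shows that a damped inverse of the factorized block only needs eigendecompositions of the two small factors.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, N = 5, 3, 1000
a = rng.standard_normal((N, n_in))   # stand-ins for layer inputs a_{l-1}, one row per example
g = rng.standard_normal((N, n_out))  # stand-ins for output gradients g_l

# Per-example weight gradient is g a^T; its column-major vec equals kron(a, g),
# so its outer product is (a a^T) kron (g g^T), matching the block formula.
vec_grad = np.outer(g[0], a[0]).flatten(order="F")
assert np.allclose(vec_grad, np.kron(a[0], g[0]))

# Exact block (averaged): E[(a a^T) kron (g g^T)], size (n_in*n_out) x (n_in*n_out).
F_exact = np.mean([np.kron(np.outer(x, x), np.outer(v, v)) for x, v in zip(a, g)], axis=0)

# KFAC: E[A kron G] ~= E[A] kron E[G]; only n_in^2 + n_out^2 numbers are stored.
A_fac = a.T @ a / N
G_fac = g.T @ g / N
F_kfac = np.kron(A_fac, G_fac)

# Damped inverse via the factors' eigendecompositions:
# (A kron G + lam I)^{-1} = (U_A kron U_G) diag(1 / (s_A kron s_G + lam)) (U_A kron U_G)^T
lam = 0.1
s_A, U_A = np.linalg.eigh(A_fac)
s_G, U_G = np.linalg.eigh(G_fac)
U = np.kron(U_A, U_G)                # formed explicitly here only to check the identity
F_inv = U @ np.diag(1.0 / (np.kron(s_A, s_G) + lam)) @ U.T
print(np.allclose(F_inv, np.linalg.inv(F_kfac + lam * np.eye(n_in * n_out))))  # True
```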
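- The Laplace approximation has strong connections to **linearizing the neural network** around $\theta_{\text{MAP}}$.
- **Predicting under the linearized model is more effective** than, e.g., sampling weights from the approximate posterior and running the full network on each sample. In particular, the network is replaced by its first-order Taylor expansion,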
$$
f_{\boldsymbol{\theta}}\left(\mathrm{x}_*\right) \approx f_{\boldsymbol{\theta}_{\mathrm{MAP}}}\left(\mathrm{x}_*\right)+\left.\nabla_{\boldsymbol{\theta}} f_{\boldsymbol{\theta}}\left(\mathrm{x}_*\right)\right|_{\boldsymbol{\theta}_{\mathrm{MAP}}} ^T\left(\boldsymbol{\theta}-\boldsymbol{\theta}_{\mathrm{MAP}}\right) .
$$
where $\mathrm{x}_*$ is a test input. This approach is also known as the **linearized Laplace approximation**.
- Because the linearized model is linear in $\boldsymbol{\theta}$ and the approximate posterior is Gaussian, the **weights can be integrated out in closed form, giving a Gaussian distribution over the output logits**,
$$
f_{\boldsymbol{\theta}}\left(\mathrm{x}_*\right) \sim \mathcal{N}\left(f_{\theta_{\mathrm{MAP}}}\left(\mathrm{x}_*\right), \boldsymbol{\Lambda}\right),
$$
where
$$
\boldsymbol{\Lambda}=\left(\left.\nabla_{\boldsymbol{\theta}} f_{\boldsymbol{\theta}}\left(\mathrm{x}_*\right)\right|_{\boldsymbol{\theta}_{\mathrm{MAP}}} ^T\right) \boldsymbol{\Sigma}\left(\left.\nabla_{\boldsymbol{\theta}} f_{\boldsymbol{\theta}}\left(\mathrm{x}_*\right)\right|_{\boldsymbol{\theta}_{\mathrm{MAP}}}\right) .
$$
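The linearized predictive can be sketched with a hypothetical two-layer NumPy network. The stand-in $\theta_{\text{MAP}}$ and $\boldsymbol{\Sigma}$ are random and isotropic rather than fitted, and the Jacobian is taken by central finite differences only to keep the example dependency-free (autodiff would be used in practice). The Jacobian `J` here is $(\nabla_{\boldsymbol{\theta}} f)^T$, so $\boldsymbol{\Lambda} = \mathbf{J}\boldsymbol{\Sigma}\mathbf{J}^T$.

```python
import numpy as np

def f(theta, x):
    """Hypothetical tiny network: 2 inputs -> 3 tanh hidden units -> 2 logits; theta is flat."""
    W1 = theta[:6].reshape(3, 2)
    W2 = theta[6:].reshape(2, 3)
    return W2 @ np.tanh(W1 @ x)

def jacobian(theta, x, eps=1e-6):
    """d f / d theta at theta, shape (n_logits, n_params), by central differences."""
    cols = []
    for i in range(theta.size):
        e = np.zeros_like(theta)
        e[i] = eps
        cols.append((f(theta + e, x) - f(theta - e, x)) / (2 * eps))
    return np.stack(cols, axis=1)

rng = np.random.default_rng(0)
theta_map = rng.standard_normal(12)   # stand-in for theta_MAP
Sigma = 0.01 * np.eye(12)             # stand-in posterior covariance
x_star = np.array([0.5, -1.0])        # test input

J = jacobian(theta_map, x_star)       # (n_logits, n_params) = (grad_theta f)^T
mean = f(theta_map, x_star)           # f_{theta_MAP}(x_*)
Lambda = J @ Sigma @ J.T              # (2, 2) covariance of the logits
print(mean, Lambda)
```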
- Subsequently, optimize the prior precision $\lambda$ by maximizing the **closed-form Laplace approximation to the marginal likelihood (model evidence)** on the training dataset,
$$
\mathrm{P}(\mathbf{y} \mid \mathbf{X})=\int \mathrm{P}(\mathbf{y} \mid \mathbf{X}, \boldsymbol{\theta}) \mathrm{P}(\boldsymbol{\theta}) d \boldsymbol{\theta} \approx \exp \left(\mathcal{L}\left(\mathbf{y}, \mathbf{X} ; \boldsymbol{\theta}_{\mathrm{MAP}}\right)\right)(2 \pi)^{D / 2}|\boldsymbol{\Sigma}|^{1 / 2},
$$
where $D$ is the number of parameters, and both $\mathcal{L}$ (through the log-prior) and $\boldsymbol{\Sigma}$ (through the $\lambda \mathbf{I}$ term) depend on $\lambda$.
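A minimal sketch of the prior-precision step, maximizing $\log \mathrm{P}(\mathbf{y} \mid \mathbf{X}) \approx \mathcal{L}(\mathbf{y}, \mathbf{X}; \boldsymbol{\theta}_{\mathrm{MAP}}) + \tfrac{D}{2}\log 2\pi + \tfrac{1}{2}\log|\boldsymbol{\Sigma}|$, where $\mathcal{L}$ is the log-joint (log-likelihood plus Gaussian log-prior). It assumes, as in a post-hoc setting, that $\theta_{\text{MAP}}$ and the eigenvalues of the log-likelihood curvature (negative Hessian or Fisher) stay fixed while $\lambda$ varies. The function names and the grid search are hypothetical choices; gradient-based optimization of $\log\lambda$ would also work.

```python
import numpy as np

def log_evidence(lam, log_lik_map, theta_map, h_eigs):
    """Laplace log marginal likelihood as a function of the prior precision lam.

    Assumption: theta_map and the eigenvalues h_eigs of the negative
    log-likelihood Hessian (or Fisher) at theta_map are held fixed.
    """
    D = theta_map.size
    log_prior = 0.5 * D * np.log(lam / (2 * np.pi)) - 0.5 * lam * theta_map @ theta_map
    log_det_sigma = -np.sum(np.log(h_eigs + lam))   # Sigma = (H + lam I)^{-1}
    return log_lik_map + log_prior + 0.5 * D * np.log(2 * np.pi) + 0.5 * log_det_sigma

def optimize_prior_precision(log_lik_map, theta_map, h_eigs, lams=np.logspace(-3, 3, 6001)):
    """Grid search for the lam that maximizes the Laplace evidence."""
    scores = [log_evidence(lam, log_lik_map, theta_map, h_eigs) for lam in lams]
    return lams[int(np.argmax(scores))]

# Toy check: D = 1, theta_map^2 = 0.5, H = 1. Setting d/dlam = 0 gives lam (1 + lam) = 2, so lam = 1.
best = optimize_prior_precision(0.0, np.array([np.sqrt(0.5)]), np.array([1.0]))
print(round(best, 3))  # 1.0
```

Only training-set quantities enter this objective, which is why no separate validation set is needed.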
- Crucially, unlike other post-hoc calibration methods, **post-hoc Laplace does not require a separate validation set.** This feature is particularly **beneficial for small-scale datasets** where training data is scarce.
- To obtain samples of $f_{\boldsymbol{\theta}}\left(\mathrm{x}_*\right)$, we can decompose the covariance using the **Cholesky factorization**, $\Lambda=\mathbf{L L}^T$,
$$
\tilde{f}_{\boldsymbol{\theta}}\left(\mathrm{x}_*\right)=f_{\boldsymbol{\theta}_{\mathrm{MAP}}}\left(\mathrm{x}_*\right)+\mathbf{L} \xi,
$$
where $\xi$ is a vector of IID (Independent and Identically Distributed) standard normal random variables.
- **Compute the Bayesian model average** by passing each sampled logit vector through the softmax function and averaging the resulting probabilities over the draws of $\xi$.
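A minimal NumPy sketch of the sampling step and the Monte Carlo model average, assuming the mean logits $f_{\boldsymbol{\theta}_{\mathrm{MAP}}}(\mathrm{x}_*)$ and the covariance $\boldsymbol{\Lambda}$ have already been computed. The example values and the sample count are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def bma_probs(mean_logits, Lambda, n_samples=10_000, seed=0):
    """Monte Carlo Bayesian model average under f ~ N(mean_logits, Lambda)."""
    rng = np.random.default_rng(seed)
    L = np.linalg.cholesky(Lambda)                           # Lambda = L L^T
    xi = rng.standard_normal((n_samples, mean_logits.size))  # IID standard normals
    logits = mean_logits + xi @ L.T                          # each row is f_MAP + L xi
    return softmax(logits).mean(axis=0)                      # average the probabilities

mean_logits = np.array([2.0, 0.0, 0.0])
Lambda = 4.0 * np.eye(3)
print(softmax(mean_logits))            # MAP prediction, about [0.787, 0.107, 0.107]
print(bma_probs(mean_logits, Lambda))  # less confident in class 0
```

Averaging probabilities (rather than averaging logits and then applying softmax) is what makes the prediction less confident when the logit uncertainty is large.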

## Results
- The authors consider post-hoc Laplace approximations applied to the LoRA parameters (Laplace-LoRA) at **model checkpoints** $\theta_{MAP}$ obtained from standard fine-tuning.
- LoRA is applied to the **queries, values, and output layer**.
- They evaluate the **negative log-likelihood (NLL)** and **expected calibration error (ECE)** of Llama2-7B during fine-tuning on common-sense reasoning tasks.
- NLL is the **sum of the negative log probabilities that the model assigns to the true labels**,
$$
\mathrm{NLL}=\sum_{i=1}^N-\log P\left(\hat{y}_i=y_i\right),
$$
where $P\left(\hat{y}_i\right)$ is the model's output distribution and $y_i$ is the true label.
- NLL is equivalent to the cross-entropy between the one-hot data distribution and the model's output distribution. **Minimizing NLL encourages the model to assign high probability to correct answers**.
- In contrast, **ECE measures the alignment between the model's confidence and its accuracy**: it bins the predictions by their highest predicted probability and computes a weighted average, over bins, of the absolute difference between average accuracy and average confidence,
$$
\mathrm{ECE}=\sum_{m=1}^M \frac{\left|B_m\right|}{N}\left|\operatorname{acc}\left(B_m\right)-\operatorname{conf}\left(B_m\right)\right|,
$$
where $\operatorname{acc}\left(B_m\right)$ and $\operatorname{conf}\left(B_m\right)$ are the average accuracy and average confidence in each bin,
$$
\operatorname{acc}\left(B_m\right)=\frac{1}{\left|B_m\right|} \sum_{i \in B_m} 1\left(\hat{y}_i=y_i\right), \quad \operatorname{conf}\left(B_m\right)=\frac{1}{\left|B_m\right|} \sum_{i \in B_m} P\left(\hat{y}_i\right),
$$
and $\left|B_m\right|$ is the number of examples in bin $m$.
- However, **expected calibration error cannot be optimized directly like negative log-likelihood**, because a completely random model has the same (chance-level) accuracy and confidence on every datapoint and therefore achieves zero ECE.
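Both metrics are short to implement. A minimal NumPy sketch, assuming `probs` is an $(N, C)$ array of predicted class probabilities and `labels` holds integer true labels; the equal-width 10-bin scheme is a common choice and an assumption here, not necessarily the paper's setting. The usage lines reproduce the point above: a uniform predictor on balanced labels gets zero ECE while its NLL stays poor.

```python
import numpy as np

def nll(probs, labels):
    """Sum over examples of -log P(y_hat_i = y_i)."""
    return -np.sum(np.log(probs[np.arange(len(labels)), labels]))

def ece(probs, labels, n_bins=10):
    """Expected calibration error with M = n_bins equal-width confidence bins."""
    conf = probs.max(axis=1)                       # highest predicted probability
    correct = (probs.argmax(axis=1) == labels).astype(float)
    # Bin m covers (m/M, (m+1)/M]; conf > 0 for any probability vector.
    bins = np.clip(np.ceil(conf * n_bins).astype(int) - 1, 0, n_bins - 1)
    total = 0.0
    for m in range(n_bins):
        in_bin = bins == m
        if in_bin.any():
            # |B_m| / N * |acc(B_m) - conf(B_m)|
            total += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return total

# A "completely random" model: uniform probabilities on balanced binary labels.
uniform = np.full((4, 2), 0.5)
labels = np.array([0, 1, 0, 1])
print(ece(uniform, labels))   # 0.0: perfectly calibrated yet useless
print(nll(uniform, labels))   # 4 * log 2, about 2.773
```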


### LA vs LLLA, or Which Layers are Uncertain?
- Compared LA (Laplace over all LoRA parameters) with LLLA (last-layer Laplace) by plotting the standard deviation of the logits arising from various sources.

### OOD Evaluation

## Conclusion
- Proposed Laplace-LoRA, a method for Bayesian parameter-efficient fine-tuning of LLMs.
- Requires no changes to efficient implementations of the standard fine-tuning process.
- Observed significant gains in expected calibration error and negative log-likelihood, indicating an improved estimation of uncertainty.
- This can be seen as a step towards the development of more reliable and trustworthy LLMs.