<!-- <img src="https://i.imgur.com/PP10FkM.png" alt="clip performance" width="40%" height="50%"> <img src="https://thdaily.s3-us-west-1.amazonaws.com/gif_20200719232646.gif" alt="clip performance" width="40%" height="50%"> ![](https://media.tenor.com/D12KYBUCOBAAAAAd/git-merge.gif) -->

<!-- .slide: data-text-color="black" data-transition="zoom" -->

## Git Re-Basin: Merging Models modulo Permutation Symmetries

[`Samuel K. Ainsworth`](https://samlikes.pizza/), [`Jonathan Hayase`](https://scholar.google.com/citations?user=YuBhA1wAAAAJ&hl=ja), [`Siddhartha Srinivasa`](https://goodrobot.ai/).

_presented by [Albert M. Orozco Camacho](https://twitter.com/alorozco53)_

----

![](https://media.tenor.com/D12KYBUCOBAAAAAd/git-merge.gif)

---

<!-- .slide: data-transition="zoom" data-background="red"-->

## Motivation

----

- _How do neural networks relate to each other at the layer level?_
- The authors propose several methods to explain layer and activation differences between any two NNs trained on the same task.
- Applications of **model merging** span regimes where efficient training is crucial, such as distributed and federated learning, as well as continual learning.

----

Some research questions:

1. _Why does SGD thrive in optimizing high-dimensional non-convex deep learning loss landscapes despite being noticeably less robust in other non-convex optimization settings?_
2. _What are all the local minima? When linearly interpolating between initialization and final trained weights, why does the loss smoothly and monotonically decrease?_

----

3. **_How can two independently trained models with different random initializations and data batch orders inevitably achieve nearly identical performance?_**

----

<!-- .slide: data-transition="zoom" data-background="green"-->

### Searching for a Common Basin

----

<!-- .slide: data-background="white"-->

![](https://i.imgur.com/08expwj.png)

Note:
- Most SGD solutions belong to a set whose elements can be permuted so that no barrier exists on the linear interpolation between any two permuted elements.
  - This claim is offered as an answer to research question 1!
- We refer to such solutions as being _linearly mode connected_ (LMC).

----

<!-- .slide: data-background="yellow"-->

### Paper Contributions

----

- <!-- .element: class="fragment" --> <b>Matching Methods</b>
  - three novel algorithms for weight alignment of two independently trained models
- <!-- .element: class="fragment" --> <b>Relationship to SGD</b>
  - using a counterexample, the authors argue that linear mode connectivity (LMC) is a property of SGD training

----

- <!-- .element: class="fragment" --> <b>Experiments, including zero-barrier LMC for ResNets</b>
  - empirical study of LMC modulo permutation symmetries
  - demonstration of zero-barrier LMC between two ResNets
  - relationship between LMC and model width
  - the proposed methods are able to combine models into a better (merged) one

---

<!-- .slide: data-transition="zoom" data-background="blue"-->

## Background

----

<!-- .slide: data-background="white"-->

Consider an $L$-layered MLP

![](https://i.imgur.com/GRfp95Q.png)

Permutations are applied layer-wise via a permutation matrix $\mathbf{P}$

![](https://i.imgur.com/QH9qdom.png)

The authors denote by $\Theta'$ the parameter set of a model identical to $\Theta$ except for

![](https://i.imgur.com/EX18oGi.png)

Note:
- The method can be adapted to arbitrary model architectures.
- We are interested in studying _permutation symmetries_.

----

<!-- .slide: data-transition="fade"-->

> Two models are functionally equivalent when $f(\mathbf{x}; \Theta) = f(\mathbf{x}; \Theta')$ for every input $\mathbf{x}$.

Some _claims_:
- There is an entire equivalence class of functionally equivalent weight assignments.
- Convergence to any one specific element of this equivalence class, as opposed to any other, is determined only by the random seed.

----

<!-- .slide: data-background="white"-->

### Task

_Given $\Theta_A$ and $\Theta_B$, can we identify some $\pi$ such that when linearly interpolating between $\Theta_A$ and $\pi(\Theta_B)$, all intermediate models enjoy performance similar to $\Theta_A$ and $\Theta_B$?_

----

<!-- .slide: data-transition="zoom" data-background="white"-->

![](https://i.imgur.com/auBAjZ5.png)

----

<!-- .slide: data-transition="zoom" data-background="brown"-->

### Permutation Methods

----

<!-- .slide: data-transition="zoom" data-background="white"-->

#### Matching Activations

Hebbian mantra: _"[neural network units] that fire together, wire together"_.

----

<!-- .slide: data-transition="zoom" data-background="white"-->

- The authors propose associating units across two models by performing regression between their activations.
- The underlying assumption is that a linear relationship may exist between the activations of the two models.

![](https://i.imgur.com/AyMj2eA.png)

where $\langle \mathbf{A}, \mathbf{B} \rangle_F = \sum_{i,j} A_{i,j}B_{i,j}$ denotes the Frobenius inner product between real-valued matrices.

----

<!-- .slide: data-background="white"-->

- This constitutes a _linear assignment problem_ (LAP), for which efficient, practical algorithms are known.
- A permutation is applied to the weights of model $B$ to match model $A$ with

![](https://i.imgur.com/31GkR3K.png)

Note:
- This is computationally efficient, as we only need a single pass over the training dataset.
- Note that the activation matching at each layer is independent of the matching at every other layer.
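----

A minimal NumPy/SciPy sketch of activation matching for a single hidden layer, under simplifying assumptions: a toy two-layer ReLU MLP with random weights, and a model B that is an exact hidden-unit shuffle of model A, so the correct permutation is known. The names `mlp`, `match_activations` and `permute_hidden` are hypothetical, not taken from the authors' repository. The cost matrix is the batch sum of products of A's and B's hidden activations, and SciPy's `linear_sum_assignment` solves the resulting LAP.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def mlp(x, W1, b1, W2, b2):
    """Two-layer ReLU MLP; x has shape (batch, d_in). Returns (output, hidden)."""
    z = np.maximum(0.0, x @ W1.T + b1)  # hidden activations, shape (batch, hidden)
    return z @ W2.T + b2, z


def match_activations(z_a, z_b):
    """Return `perm` such that hidden unit perm[i] of B is matched to unit i of A."""
    # cost[i, j] = sum over the batch of z_a[:, i] * z_b[:, j]
    cost = z_a.T @ z_b
    _, perm = linear_sum_assignment(cost, maximize=True)
    return perm


def permute_hidden(W1, b1, W2, perm):
    """Reorder hidden units: rows of W1 and b1, columns of W2 (output bias unchanged)."""
    return W1[perm], b1[perm], W2[:, perm]


rng = np.random.default_rng(0)
d_in, hidden, d_out = 4, 8, 3
W1, b1 = rng.normal(size=(hidden, d_in)), rng.normal(size=hidden)
W2, b2 = rng.normal(size=(d_out, hidden)), rng.normal(size=d_out)

# Hypothetical model B: model A with its hidden units shuffled (functionally equivalent)
shuffle = rng.permutation(hidden)
W1_b, b1_b, W2_b = permute_hidden(W1, b1, W2, shuffle)

x = rng.normal(size=(256, d_in))
out_a, z_a = mlp(x, W1, b1, W2, b2)
out_b, z_b = mlp(x, W1_b, b1_b, W2_b, b2)

# Match B's units to A's and apply the permutation to B's weights
perm = match_activations(z_a, z_b)
W1_m, b1_m, W2_m = permute_hidden(W1_b, b1_b, W2_b, perm)
out_m, _ = mlp(x, W1_m, b1_m, W2_m, b2)
```

Permuting the rows of $\mathbf{W}_1$, $\mathbf{b}_1$ and the columns of $\mathbf{W}_2$ together is what keeps $f(\mathbf{x}; \Theta) = f(\mathbf{x}; \Theta')$. Two independently trained networks are never exact shuffles of each other, so in practice the matched model B is only approximately aligned with A.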
----

<!-- .slide: data-transition="zoom" data-background="white"-->

### Matching Weights

- The idea of matching activations can be extended to weights:
  - if two rows of a weight matrix are equal, they compute exactly the same feature (ignoring bias terms).
- The idea is to associate rows and columns such that $[\mathbf{W}_\ell^{(A)}]_{i,:} \approx [\mathbf{W}_\ell^{(B)}]_{j,:}$ for every layer $\ell$.

![](https://i.imgur.com/WteF7UV.png)

----

![](https://i.imgur.com/0toUeZb.png)

- This yields a _sum of bilinear assignments problem_ (SOBLAP), which is harder than a LAP:
  - matching rows and columns at the same time increases the difficulty,
  - it is actually known to be NP-hard!

----

- The authors simplify the problem by _greedily_ optimizing one matrix $\mathbf{P}_\ell$ at a time, while holding the others fixed:

![](https://i.imgur.com/qZ6HOqR.png)

----

<!-- .slide: data-background="white"-->

![](https://i.imgur.com/stGbHre.png)

----

<!-- .slide: data-transition="zoom" data-background="white"-->

### Straight-Through Estimator Permutations

<img src="https://d33wubrfki0l68.cloudfront.net/0bc2a98cefa6eb619f84ced904f0ffc9d79eb543/4eb1a/images/intuitive-explanation-of-ste-with-code/ste-visualization.png" alt="clip performance" width="80%" height="50%">

<font size="-1">Taken from [here](https://www.hassanaskary.com/python/pytorch/deep%20learning/2020/09/19/intuitive-explanation-of-straight-through-estimators.html)</font>

----

- The authors attempt to _"learn"_ the ideal permutation of weights $\pi(\Theta_B)$

![](https://i.imgur.com/xBxgmhA.png)

where $\tilde{\Theta}_B$ denotes an approximation of $\pi(\Theta_B)$.

- Note that $\text{proj}(\cdot)$ is non-differentiable!

----

- Yet we can overcome this issue by parameterizing $\tilde{\Theta}_B$ directly:
  - In the **forward pass**:
    - $\tilde{\Theta}_B$ is projected to the closest realizable $\pi(\Theta_B)$.
  - In the **backward pass**:
    - the gradient is passed straight through to $\tilde{\Theta}_B$, as if $\text{proj}(\cdot)$ were the identity.
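----

A toy NumPy/SciPy sketch of the straight-through idea, not the authors' implementation. Assumptions: a single weight matrix $\mathbf{W}_B$ stands in for $\Theta_B$; `proj` returns the row permutation of $\mathbf{W}_B$ closest in Frobenius norm to the free parameters $\tilde{\mathbf{W}}_B$, found by solving a LAP; and the paper's objective (the loss of the interpolated model) is swapped for a hypothetical toy loss $\|\mathbf{W} - \mathbf{W}_A\|_F^2$ so that its gradient has a closed form. The names `proj` and `ste_permutation` are illustrative only.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def proj(W_tilde, W_b):
    """Row permutation of W_b closest to W_tilde in Frobenius norm (via a LAP)."""
    # ||W_b[perm] - W_tilde||_F^2 is minimized by maximizing sum_i <W_tilde[i], W_b[perm[i]]>
    _, perm = linear_sum_assignment(W_tilde @ W_b.T, maximize=True)
    return perm


def ste_permutation(W_a, W_b, steps=200, lr=0.5):
    """Toy straight-through estimator: learn a permutation of W_b that matches W_a."""
    W_tilde = W_b.copy()  # free, unconstrained parameters, initialized at W_b
    for _ in range(steps):
        perm = proj(W_tilde, W_b)  # forward pass: snap to a realizable permutation of W_b
        grad = 2.0 * (W_b[perm] - W_a)  # gradient of the toy loss at the projected weights
        if not grad.any():  # toy loss is exactly zero: nothing left to learn
            break
        # backward pass: apply that gradient to W_tilde as if proj were the identity
        W_tilde -= lr * grad
    return proj(W_tilde, W_b)


rng = np.random.default_rng(0)
n = 6
W_b, _ = np.linalg.qr(rng.normal(size=(n, n)))  # orthonormal rows keep the toy well conditioned
true_perm = rng.permutation(n)
W_a = W_b[true_perm]  # hypothetical "model A": a row-permuted copy of "model B"
found = ste_permutation(W_a, W_b)
```

Because the loss is evaluated at the projected weights while the update lands on the continuous $\tilde{\mathbf{W}}_B$, the free parameters accumulate evidence until the projection switches to a better permutation; with a real network the gradient would come from backpropagating the task loss instead of this closed form.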
----

<!-- .slide: data-background="white"-->

![](https://i.imgur.com/mJ6EDPt.png)

----

<!-- .slide: data-background="beige"-->

### Implementation

----

https://github.com/samuela/git-re-basin

```python
import jax.numpy as jnp
from jax import random
from scipy.optimize import linear_sum_assignment

# `PermutationSpec`, `rngmix` and `get_permuted_param` are helpers defined
# elsewhere in the git-re-basin repository (not shown on this slide).


def weight_matching(rng, ps: PermutationSpec, params_a, params_b,
                    max_iter=100, init_perm=None, silent=False):
    """Find a permutation of `params_b` to make them match `params_a`."""
    # Size of each permutation = length of the axis it acts on
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]]
                  for p, axes in ps.perm_to_axes.items()}
    # Start from the identity permutations unless an initial guess is given
    perm = ({p: jnp.arange(n) for p, n in perm_sizes.items()}
            if init_perm is None else init_perm)
    perm_names = list(perm.keys())

    for iteration in range(max_iter):
        progress = False
        # Coordinate descent: visit the permutations in a random order
        for p_ix in random.permutation(rngmix(rng, iteration), len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            # Accumulate the LAP cost over every weight touched by `p`,
            # holding all other permutations fixed
            A = jnp.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T

            # Solve the linear assignment problem for this permutation
            ri, ci = linear_sum_assignment(A, maximize=True)
            assert (ri == jnp.arange(len(ri))).all()

            # Objective value before and after the update
            oldL = jnp.vdot(A, jnp.eye(n)[perm[p]])
            newL = jnp.vdot(A, jnp.eye(n)[ci, :])
            if not silent:
                print(f"{iteration}/{p}: {newL - oldL}")
            progress = progress or newL > oldL + 1e-12

            perm[p] = jnp.array(ci)

        # Stop once a full sweep no longer improves the objective
        if not progress:
            break

    return perm
```

---

<!-- .slide: data-background="purple"-->

### Experiments

----

<!-- .slide: data-transition="zoom" data-background="white"-->

#### Loss Barriers

![](https://i.imgur.com/xiOVNL4.png)

- This gives a way to measure how the loss (or error) varies from one model to another, modulo permutation, by varying the parameter $\lambda \in [0, 1]$.
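----

A small NumPy sketch of estimating a loss barrier on an evenly spaced grid of $\lambda$ values. It assumes the barrier is measured relative to the mean of the two endpoint losses (check this against the exact definition on the slide), and it uses a hypothetical double-well loss in place of a real network's test loss. The toy loss is invariant to permuting coordinates, which mimics a network's invariance to permuting hidden units.

```python
import numpy as np


def loss_barrier(loss_fn, theta_a, theta_b, num_points=25):
    """Grid estimate of max_lambda L((1 - lambda) theta_a + lambda theta_b)
    minus the mean of the two endpoint losses (assumed definition)."""
    lambdas = np.linspace(0.0, 1.0, num_points)
    path = np.array([loss_fn((1 - lam) * theta_a + lam * theta_b) for lam in lambdas])
    return path.max() - 0.5 * (loss_fn(theta_a) + loss_fn(theta_b)), path


def double_well(theta):
    """Hypothetical toy loss: minima at +1 and -1 in every coordinate."""
    return float(np.sum((theta ** 2 - 1.0) ** 2))


theta_a = np.array([1.0, -1.0])
theta_b = np.array([-1.0, 1.0])  # same loss as theta_a, coordinates swapped

naive_barrier, _ = loss_barrier(double_well, theta_a, theta_b)
# Interpolate towards a coordinate-permuted copy of theta_b instead
permuted_barrier, _ = loss_barrier(double_well, theta_a, theta_b[[1, 0]])
```

The naive path crosses the bump at the origin (a barrier of 2), while interpolating towards the permuted $\theta_B$ gives a flat path (a barrier of 0, up to floating-point rounding): a two-parameter caricature of what the permutation methods aim for.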
----

<!-- .slide: data-transition="zoom" data-background="white"-->

![](https://github.com/samuela/git-re-basin/raw/main/mnist_video.gif)

----

<!-- .slide: data-transition="zoom" data-background="white"-->

#### Naïve Interpolation ($\pi(\Theta_B) = \Theta_B$) vs. Proposed Methods

![](https://i.imgur.com/ybnTdrB.png)

Note:
- Experiments with MNIST, CIFAR-10, and ImageNet.
- MNIST achieves zero-barrier LMC.
- The authors found that weight matching offered a compelling balance between computational cost and performance.

----

<!-- .slide: data-transition="zoom" data-background="white"-->

#### Effect of Model Width

![](https://i.imgur.com/9dDTImQ.png)

Note:
- The authors note that a clear relationship emerges between model width and linear mode connectivity, as measured by the loss barrier between solutions.

----

<!-- .slide: data-transition="zoom" data-background="white"-->

#### Model Patching, Split Data Training, and Improved Calibration

<font size=-1> Can we use permutations to actually merge models trained on disjoint datasets? </font>

<img src="https://i.imgur.com/FGnpM71.png" alt="clip performance" width="45%" height="50%">

Note:
- Merging separately trained models did not match the performance of an omniscient model trained on the full dataset, or of an ensemble of the two models with twice the number of effective weights.
- However, the authors did manage to merge the two models in weight space, obtaining an interpolated model that outperforms both input models in terms of test loss while using half the memory and compute required for ensembling.

---

<!-- .slide: data-transition="zoom" data-background="pink"-->

## Epilogue & Interesting Stuff

----

<!-- .slide: data-background="pink"-->

![](https://i.pinimg.com/564x/95/2d/a0/952da03b73f0810c8d58c4087bccb509.jpg)

----

Suppose we have two models with weight parameters $\mathbf{W}^{(A)}$ and $\mathbf{W}^{(B)}$, and we want to align them by minimizing a loss function $L(\mathbf{W}^{(A)}, \mathbf{W}^{(B)})$.
We can do this using an optimization algorithm such as gradient descent, which iteratively updates the weight parameters according to the following rule:

$$
\mathbf{W}^{(A)} \leftarrow \mathbf{W}^{(A)} - \alpha \nabla_{\mathbf{W}^{(A)}} L(\mathbf{W}^{(A)}, \mathbf{W}^{(B)})
$$

where $\alpha$ is the learning rate and $\nabla_{\mathbf{W}^{(A)}} L(\mathbf{W}^{(A)}, \mathbf{W}^{(B)})$ is the gradient of the loss function with respect to the weight parameters of model A.
This process is repeated until the weight parameters of the two models are aligned to the desired level of precision, or until a maximum number of iterations has been reached.

```python
import numpy as np


def align_models(model_A, model_B, max_iterations, learning_rate, tol=1e-8):
    # model_A / model_B: any objects whose get_weights() returns a list of
    # NumPy arrays with matching shapes (e.g. Keras models)
    weights_A = [np.array(w, dtype=float) for w in model_A.get_weights()]
    weights_B = [np.array(w, dtype=float) for w in model_B.get_weights()]
    n = sum(w.size for w in weights_A)  # total number of weight parameters

    # Loss: mean squared error between the weights (one concrete choice among several)
    def loss(weights_A, weights_B):
        return sum(np.sum((a - b) ** 2) for a, b in zip(weights_A, weights_B)) / n

    # Iteratively update the weights of model A using gradient descent
    for i in range(max_iterations):
        # Gradient of the MSE with respect to each weight array of model A
        grads_A = [2.0 * (a - b) / n for a, b in zip(weights_A, weights_B)]
        # Gradient descent step; model B is left unchanged
        weights_A = [a - learning_rate * g for a, g in zip(weights_A, grads_A)]
        # Stop once the weights are aligned to the desired level of precision
        if loss(weights_A, weights_B) < tol:
            break

    # Return the aligned weight parameters of model A and model B
    return weights_A, weights_B
```

There are many different ways to define the loss function $L(\mathbf{W}^{(A)}, \mathbf{W}^{(B)})$ for aligning the weight parameters of two models.
Here are a few examples:

- Mean squared error: One option is to use the mean squared error between the weight parameters of the two models as the loss function:
  $$
  L(\mathbf{W}^{(A)}, \mathbf{W}^{(B)}) = \frac{1}{n}\sum_{i=1}^n \left(\mathbf{W}_i^{(A)} - \mathbf{W}_i^{(B)}\right)^2
  $$
  where $n$ is the number of weight parameters in the models. This loss function measures the average squared difference between the weight parameters of the two models, and can be used to encourage the weight parameters to be as similar as possible.
- Weighted mean squared error: Alternatively, you could weight each squared difference, for example by the magnitude of the corresponding parameter, to give more importance to the larger parameters:
  $$
  L(\mathbf{W}^{(A)}, \mathbf{W}^{(B)}) = \frac{1}{n}\sum_{i=1}^n w_i \left(\mathbf{W}_i^{(A)} - \mathbf{W}_i^{(B)}\right)^2
  $$
  where $w_i \geq 0$ is the weight assigned to parameter $\mathbf{W}_i$.
- Cosine similarity: Another option is to build the loss function from the cosine similarity between the weight parameters, viewed as flattened vectors:
  $$
  L(\mathbf{W}^{(A)}, \mathbf{W}^{(B)}) = 1 - \frac{\mathbf{W}^{(A)}\cdot \mathbf{W}^{(B)}}{\|\mathbf{W}^{(A)}\|_2 \, \|\mathbf{W}^{(B)}\|_2}
  $$
  This loss depends only on the angle between the two weight vectors: smaller angles indicate a higher level of similarity and give a smaller loss.

These are just a few examples of loss functions that could be used for aligning the weight parameters of two models. There are many other possibilities, and the choice of loss function will depend on the specific problem you are trying to solve and your goals for aligning the models.
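----

The three losses can be sketched in NumPy as follows, assuming each model's weights are given as a list of arrays (as returned by, e.g., a Keras model's `get_weights()`) and are flattened into one vector. The helper names and the `importance` argument (one $w_i$ per parameter) are hypothetical.

```python
import numpy as np


def flatten(weights):
    """Concatenate a list of weight arrays into a single 1-D vector."""
    return np.concatenate([np.ravel(w) for w in weights])


def mse_loss(weights_A, weights_B):
    a, b = flatten(weights_A), flatten(weights_B)
    return np.mean((a - b) ** 2)


def weighted_mse_loss(weights_A, weights_B, importance):
    # `importance` holds one non-negative w_i per parameter, same shapes as the weights
    a, b, w = flatten(weights_A), flatten(weights_B), flatten(importance)
    return np.mean(w * (a - b) ** 2)


def cosine_loss(weights_A, weights_B):
    a, b = flatten(weights_A), flatten(weights_B)
    return 1.0 - (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))


# Hypothetical two-layer weights that differ in a single entry
weights_A = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([0.5, -0.5])]
weights_B = [np.array([[1.0, 2.0], [3.0, 6.0]]), np.array([0.5, -0.5])]
# Give larger parameters more importance: w_i = |W_i^(A)|
importance = [np.abs(w) for w in weights_A]

mse = mse_loss(weights_A, weights_B)
weighted = weighted_mse_loss(weights_A, weights_B, importance)
cosine = cosine_loss(weights_A, weights_B)
```

The cosine loss ignores the overall scale of the weights, whereas both MSE variants do not. None of the three is invariant to permuting hidden units, which is why the permutation methods discussed earlier align units before comparing weights.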