Cross-Gradient Training
===

**Title**: Generalizing Across Domains via Cross-Gradient Training

**Authors**: Shiv Shankar, Vihari Piratla, Sunita Sarawagi

**Year**: 2018

**Link**: <https://arxiv.org/pdf/1804.10745.pdf>

How can a model be trained to make predictions on unseen domains?

- The proposed solution, **CrossGrad**, does not need an adaptation step with labeled or unlabeled data.
- It also does not require any domain-specific features from the new domains.
- Existing solutions erase domain-specific information from samples. **CrossGrad** is free to use that information as long as it does not overfit to the training domains. This is achieved by data augmentation guided by the domain prediction loss.

Empirical results show that CrossGrad outperforms Domain Adversarial Networks.

## Introduction

The paper exploits training data from multiple domains in order to generalize to unseen domains. Most existing solutions require labeled or unlabeled data from the target domain, plus a separate "adaptation" step. This often results in "forgetting" all the domain-specific information in the sample. The paper instead trains on data from multiple domains and predicts on out-of-domain data *without an adaptation step*.

### Problem Statement

- The full space of domains is $\mathbf{D}$. The subset of domains present in training is $D \subset \mathbf{D}$.
- Each labeled instance is a triple $(x, y, d)$, where $d$ is the domain. The goal is to train a model on the domains in $D$ so that it performs well on the entire space $\mathbf{D}$.
- $P(y|x)$ is much harder to learn than $P(y|x, d)$; the latter is what models typically learn.
- Domain Adversarial Networks learn a transformation of $x$ that makes it domain-invariant, hoping that this makes the system domain-independent. Simply erasing domain information is not a good idea, since we *also want* good performance on $D$.

### Contribution

- The proposed solution does not use any domain adaptation step.
- It relies on **data augmentation based on domain-guided perturbations of inputs.**

![](https://i.imgur.com/AxU8zCg.png)

Assume that the data is generated by the first Bayesian process in the figure above.

- $g$ captures a multivariate representation of the discrete domain $d$.
- If we knew how the domain signals from $d$ manifest in $x$, we could simply replace them with signals sampled from other domains.
- From the diagram, if we could perfectly recover $g$, we could perturb it to create an augmented instance $x'$. Tough luck: $g$ cannot be extracted perfectly.
- Instead, we train a domain classifier.
- Given an instance $x$, take the gradient of this network's loss w.r.t. $x$.
- This *vector* gives the direction in which the domain classifier's loss changes the most. We perturb $x$ along this gradient to get $x'$.
- Finally, the training loss of the $y$-predictor on the original $x$ is combined with its training loss on $x'$.

## Approach

The inference task is $P(y | x)$. Since the domain is not given at inference time, it has to be marginalized out, which requires an estimate over $g$. For $|D|$ discrete domains, the true $P(y | x)$ is given by

$$
P(y|x) = \sum_{d\in D} P(y|x, d) P(d|x)
$$

> The assumption states that as long as the training domains span the latent continuous features, we can generalize to new fonts and speakers.

With this assumption, we can rewrite it as

$$
\int_g P(y | x, g) P(g|x) \approx P(y | x, \hat{g})
$$

where $\hat{g} = \text{argmax}_g P(g|x)$ is the inferred continuous representation of the domain of $x$.

How do we estimate $P(y | x, \hat{g})$ and $\hat{g}$?

- The main difficulty in estimating $P(y | x, \hat{g})$ is ensuring that it does not overfit to the inferred $g$ of the training data.
- The paper generalizes by moving along the continuous space of $g$ to sample examples from new *hallucinated domains*.
- Ideally, we should be able to perturb the domain *without* changing the label.
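The marginalization over domains and its point-estimate approximation can be made concrete with toy numbers. This sketch works with discrete domains for simplicity: `plug_in_label_posterior` mimics the approximation $P(y|x, \hat{g})$ by conditioning only on the most probable domain. Function names and all numbers are hypothetical, chosen purely for illustration.

```python
import numpy as np


def marginal_label_posterior(p_y_given_xd: np.ndarray, p_d_given_x: np.ndarray) -> np.ndarray:
    """Exact P(y|x) = sum_d P(y|x, d) P(d|x).

    p_y_given_xd: shape (num_domains, num_labels); row d holds P(y|x, d).
    p_d_given_x:  shape (num_domains,); the domain posterior P(d|x).
    """
    return p_d_given_x @ p_y_given_xd


def plug_in_label_posterior(p_y_given_xd: np.ndarray, p_d_given_x: np.ndarray) -> np.ndarray:
    """Point-estimate approximation: condition only on the most probable domain."""
    return p_y_given_xd[np.argmax(p_d_given_x)]


# Toy, made-up numbers: 3 training domains, 2 labels.
p_y_given_xd = np.array([[0.9, 0.1],
                         [0.7, 0.3],
                         [0.2, 0.8]])
p_d_given_x = np.array([0.8, 0.15, 0.05])

exact = marginal_label_posterior(p_y_given_xd, p_d_given_x)
approx = plug_in_label_posterior(p_y_given_xd, p_d_given_x)
print("exact   P(y|x):", exact)   # approximately [0.835, 0.165]
print("plug-in P(y|x):", approx)  # [0.9, 0.1]
```

The plug-in estimate is close to the exact marginal here because $P(d|x)$ is sharply peaked; it degrades as the domain posterior becomes flatter.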

To keep the label unchanged, the perturbation is a small step along the direction of the estimated domain features $g$, chosen so that it changes the label as little as possible. This is achieved as follows:

1. Train a classifier $G$ to extract domain features from $x_i$ in a supervised manner by predicting its true discrete domain label $d_i$. $J_d(x_i, d_i)$ is the cross-entropy loss of this classifier on that instance.
2. Sample $x_i'$ that **has the same label as** $x_i$ but whose domain is as far away as possible from that of $x_i$.
    - This is achieved by $x_i' = x_i + \epsilon \nabla_{x_i} J_d(x_i, d_i)$.
    - This assumes that a small move along the domain-loss gradient will NOT change the label.
3. To enforce that the label does not change, the domain extractor $G$ is also trained on label-perturbed data (inputs moved along the gradient of the label classifier's loss) while still predicting the original domain, so that label-changing directions are not picked up as domain features.

![](https://i.imgur.com/IcioHs4.png)
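The three steps can be sketched in NumPy. This is a minimal illustration under assumptions that the notes do not state: both the label classifier and the domain classifier $G$ are linear softmax models (so the input gradient has a closed form), the label classifier does not take the domain features as input, and clean and perturbed losses are blended with a hypothetical weight `alpha`. All function names and the default step sizes `eps_l` and `eps_d` are hypothetical. Only the loss computation is shown; an actual training loop would differentiate these losses w.r.t. the classifier parameters, for example with an autodiff library.

```python
import numpy as np

# Minimal CrossGrad-style sketch; names and hyperparameters are hypothetical.


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def cross_entropy(W: np.ndarray, b: np.ndarray, X: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy of a linear softmax classifier (logits = X @ W.T + b)."""
    p = softmax(X @ W.T + b)
    return float(-np.mean(np.log(p[np.arange(len(labels)), labels])))


def input_gradient(W: np.ndarray, b: np.ndarray, X: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Per-example gradient of the cross-entropy loss w.r.t. each input row of X.

    For logits z = W x + b, d(loss)/dx = W^T (softmax(z) - onehot(label)).
    """
    p = softmax(X @ W.T + b)
    p[np.arange(len(labels)), labels] -= 1.0  # softmax(z) - onehot(label)
    return p @ W  # shape (n, num_features)


def crossgrad_augment(X, y, d, label_clf, domain_clf, eps_l=0.5, eps_d=0.5):
    """Return (domain-perturbed X, label-perturbed X)."""
    W_l, b_l = label_clf
    W_d, b_d = domain_clf
    # Step along the domain-loss gradient: a "hallucinated" domain, label assumed unchanged.
    X_dom = X + eps_d * input_gradient(W_d, b_d, X, d)
    # Step along the label-loss gradient: label signal shifted, domain assumed unchanged.
    X_lab = X + eps_l * input_gradient(W_l, b_l, X, y)
    return X_dom, X_lab


def crossgrad_losses(X, y, d, label_clf, domain_clf, alpha=0.5, eps_l=0.5, eps_d=0.5):
    """Blend clean and perturbed losses for the label and domain classifiers."""
    W_l, b_l = label_clf
    W_d, b_d = domain_clf
    X_dom, X_lab = crossgrad_augment(X, y, d, label_clf, domain_clf, eps_l, eps_d)
    # Label classifier: clean inputs plus domain-perturbed inputs with the ORIGINAL labels y.
    loss_label = (1 - alpha) * cross_entropy(W_l, b_l, X, y) + alpha * cross_entropy(W_l, b_l, X_dom, y)
    # Domain classifier: clean inputs plus label-perturbed inputs with the ORIGINAL domains d.
    loss_domain = (1 - alpha) * cross_entropy(W_d, b_d, X, d) + alpha * cross_entropy(W_d, b_d, X_lab, d)
    return loss_label, loss_domain


# Usage on random data (shapes and values are arbitrary).
rng = np.random.default_rng(0)
n, k, num_labels, num_domains = 8, 5, 3, 4
X = rng.normal(size=(n, k))
y = rng.integers(0, num_labels, size=n)
d = rng.integers(0, num_domains, size=n)
label_clf = (rng.normal(size=(num_labels, k)), np.zeros(num_labels))
domain_clf = (rng.normal(size=(num_domains, k)), np.zeros(num_domains))

X_dom, X_lab = crossgrad_augment(X, y, d, label_clf, domain_clf)
loss_label, loss_domain = crossgrad_losses(X, y, d, label_clf, domain_clf)
print(f"label loss: {loss_label:.4f}, domain loss: {loss_domain:.4f}")
```

The domain-perturbed inputs `X_dom` keep their original labels `y` when training the label classifier, which exposes it to "hallucinated" domains. The label-perturbed inputs `X_lab` keep their original domains `d`, which discourages $G$ from treating label-changing directions as domain features.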