###### tags: `paper summary` `model` `latent state`
# Representation learning for neural population activity with Neural Data Transformers
* Year: 2021
* Conference: bioRxiv
* Link: https://www.biorxiv.org/content/10.1101/2021.01.16.426955v3
* MLA: Ye, Joel, and Chethan Pandarinath. "Representation learning for neural population activity with Neural Data Transformers." bioRxiv (2021).
## Overview
Neural populations are theorized to **have an underlying dynamical structure**. This structure can be modeled with **dynamical systems**, such as RNNs (e.g. LFADS).
These models **relate neural population activity to behavior in individual trials** better than traditional baselines (such as trial-averaged activity, smoothing, or GPFA).
RNNs are also widely used for language modeling, but they are slower than Transformers because their **recurrence** forces timesteps to be processed one after another.
This paper therefore introduces the Neural Data Transformer (NDT) for modeling neural population spiking activity. The model is essentially a **BERT-style encoder**. One key difference from the language setting is that neuroscience datasets are **much smaller** than natural language processing datasets.
On the monkey dataset, NDT **infers single-trial neural population firing rates more accurately** and is **6.7x faster** than LFADS.
## Method
NDT **transforms sequences of binned spiking activity into inferred firing rates**, as LFADS does.

In real-time applications, the sequence of spiking activity would **come from a rolling window of recent activity** that **ends with the current timestep**.
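The rolling window can be illustrated with a small buffer. This is only a sketch: the `RollingWindow` class and its parameter names are hypothetical and do not come from the paper or its code.

```python
from collections import deque


class RollingWindow:
    """Keeps the most recent `length` bins of spike counts (hypothetical helper)."""

    def __init__(self, length, n_channels):
        # Pre-fill with zeros so the window always has a fixed shape.
        self.bins = deque([[0] * n_channels for _ in range(length)], maxlen=length)

    def push(self, spike_counts):
        # Appending to a full deque drops the oldest bin automatically.
        self.bins.append(list(spike_counts))
        # Oldest bin first, current timestep last.
        return list(self.bins)


window = RollingWindow(length=3, n_channels=2)
for counts in ([1, 0], [0, 2], [3, 1], [0, 0]):
    sequence = window.push(counts)

print(sequence)  # the last element is the current timestep
```

Each call to `push` returns the sequence that would be fed to the model at that timestep.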
Both models are trained with a **Poisson likelihood-based objective (negative log-likelihood, NLL)** that compares the inferred rates against the **observed spiking activity**.
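A minimal sketch of the Poisson NLL follows, assuming rates are expressed as expected spike counts per bin rather than spikes/sec. The function name `poisson_nll` and the `eps` guard are illustrative choices, not the paper's implementation.

```python
import math


def poisson_nll(rates, spikes, eps=1e-8):
    """Mean Poisson negative log-likelihood over all (timestep, channel) bins.

    rates:  expected spike counts per bin, indexed [timestep][channel]
    spikes: observed spike counts per bin, same shape
    """
    total = 0.0
    n_bins = 0
    for rate_row, spike_row in zip(rates, spikes):
        for lam, k in zip(rate_row, spike_row):
            # -log P(k | lam) = lam - k * log(lam) + log(k!)
            total += lam - k * math.log(lam + eps) + math.lgamma(k + 1)
            n_bins += 1
    return total / n_bins


spikes = [[0, 2], [1, 0]]                 # timesteps x channels
close_rates = [[0.1, 1.9], [1.0, 0.2]]    # roughly matches the spikes
far_rates = [[2.0, 0.1], [0.1, 2.0]]      # mismatched rates

print(poisson_nll(close_rates, spikes) < poisson_nll(far_rates, spikes))
```

Rates that track the observed counts yield a lower NLL, which is what training minimizes.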
### Self-attention Block


Self-attention computes three representations from each $n$-dimensional input embedding $x_i$: a query ($q_i$), a key ($k_i$), and a value ($v_i$). The output $y_i$ is a **weighted sum of the values, with attention weights $w_i^j$**.
$$
\begin{aligned}
y_i &= \sum_{j=1}^T w_i^j v_j \\
s_i^j &= q_i \cdot k_j \\
w_i^j &= \frac{\exp(s_i^j)}{\sum_{l=1}^T \exp(s_i^l)}
\end{aligned}
$$
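The equations above can be sketched in plain Python. This is a toy, single-head version: queries, keys, and values are passed in directly (in a Transformer they are learned linear projections of $x_i$), and the scores are left unscaled to match the formulas in this note. The function names are illustrative.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def self_attention(queries, keys, values):
    """Return outputs y_i and weights w_i^j for unscaled dot-product attention."""
    outputs, weights = [], []
    for q in queries:
        # s_i^j = q_i . k_j
        scores = [dot(q, k) for k in keys]
        # Softmax over j; subtracting the max is only for numerical stability.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        norm = sum(exps)
        w = [e / norm for e in exps]
        # y_i = sum_j w_i^j v_j
        y = [sum(w_j * v[d] for w_j, v in zip(w, values))
             for d in range(len(values[0]))]
        outputs.append(y)
        weights.append(w)
    return outputs, weights


queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
outputs, weights = self_attention(queries, keys, values)
print(weights[0], outputs[0])
```

Each row of `weights` sums to 1, so every output is a convex combination of the value vectors.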
### Unsupervised Training

The model is given an input sequence $x_1...x_T$, and **a fixed ratio of its tokens is randomly masked**.
The model is then asked to **reconstruct the original inputs** at the masked positions, so it must learn how to **trade off the influence** of high and low spike counts (**similar to coordinated dropout**).
* zero mask (dropout)
* use intensive regularization to stabilize training
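The zero-masking step might look like the following sketch. It assumes one token corresponds to one time bin (a vector of spike counts across channels); the function name `zero_mask` and its parameters are hypothetical.

```python
import random


def zero_mask(sequence, mask_ratio, rng):
    """Zero out a fixed fraction of timesteps.

    Returns a masked copy of the sequence and the sorted masked indices.
    """
    n_masked = max(1, round(mask_ratio * len(sequence)))
    masked_idx = sorted(rng.sample(range(len(sequence)), n_masked))
    # Copy so the original spikes stay available as reconstruction targets.
    masked = [list(row) for row in sequence]
    for t in masked_idx:
        masked[t] = [0] * len(sequence[t])
    return masked, masked_idx


rng = random.Random(0)
spikes = [[1, 0, 2], [0, 3, 1], [2, 2, 0], [0, 1, 1]]  # timesteps x channels
masked, idx = zero_mask(spikes, mask_ratio=0.25, rng=rng)
print(idx, masked)
```

In a BERT-style setup, the reconstruction loss would typically be evaluated only at the returned indices, with the unmasked `spikes` as targets.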
## Experimental Results
### Lorenz and Chaotic RNN Synthetic Datasets
* train-val split: 0.8 / 0.2
* Lorenz: 1560 trials, 50 timesteps, 29 channels
* chaotic RNN: 1300 trials, 100 timesteps, 50 channels
* $R^2$ is calculated by flattening timesteps and trials, and averaging across input channels
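The $R^2$ computation can be sketched as below. It assumes $R^2$ means the coefficient of determination ($1 - SS_{res}/SS_{tot}$) between ground-truth and inferred rates; the function name `average_r2` is illustrative, and a channel with constant ground truth would need special handling to avoid division by zero.

```python
def average_r2(true_rates, pred_rates):
    """Average per-channel R^2; inputs are indexed [trial][timestep][channel]."""
    n_channels = len(true_rates[0][0])
    scores = []
    for c in range(n_channels):
        # Flatten trials and timesteps into one vector per channel.
        y = [step[c] for trial in true_rates for step in trial]
        y_hat = [step[c] for trial in pred_rates for step in trial]
        mean_y = sum(y) / len(y)
        ss_res = sum((a - b) ** 2 for a, b in zip(y, y_hat))
        ss_tot = sum((a - mean_y) ** 2 for a in y)
        scores.append(1.0 - ss_res / ss_tot)
    # Average across input channels.
    return sum(scores) / n_channels


true_rates = [[[1.0, 2.0], [2.0, 4.0]], [[3.0, 6.0], [4.0, 8.0]]]
mean_pred = [[[2.5, 5.0], [2.5, 5.0]], [[2.5, 5.0], [2.5, 5.0]]]

print(average_r2(true_rates, true_rates))  # perfect prediction
print(average_r2(true_rates, mean_pred))   # predicting the per-channel mean
```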


### Monkey J Maze dataset
* 202 neurons from primary motor and dorsal premotor cortices
* 2296 trials across 108 different reach conditions
* trained on the autonomous period (250 ms before movement to 450 ms after)
* 10 ms as bin width
* ridge regression with $\alpha=0.01$
#### Peri-stimulus Time Histograms (PSTHs)
PSTHs are computed by averaging inferred rates across repeated trials of the same reach condition. A Gaussian filter with a 30 ms standard-deviation kernel serves as the smoothing baseline.
Vertical bar denotes spikes/sec.
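The smoothing baseline and the trial averaging can be sketched as follows for a single channel. With 10 ms bins, a 30 ms standard deviation corresponds to 3 bins. Truncating the kernel at 3 standard deviations and renormalizing it at the edges are assumptions of this sketch, and the function names are illustrative.

```python
import math


def gaussian_smooth(signal, sigma_bins, truncate=3.0):
    """Smooth a 1-D series with a Gaussian kernel, renormalized at the edges."""
    radius = int(truncate * sigma_bins + 0.5)
    offsets = range(-radius, radius + 1)
    kernel = [math.exp(-0.5 * (k / sigma_bins) ** 2) for k in offsets]
    smoothed = []
    for t in range(len(signal)):
        num = den = 0.0
        for k, w in zip(offsets, kernel):
            if 0 <= t + k < len(signal):
                num += w * signal[t + k]
                den += w
        smoothed.append(num / den)
    return smoothed


def psth(trials):
    """Average trials of one condition: trials[trial][timestep] -> mean per timestep."""
    return [sum(column) / len(trials) for column in zip(*trials)]


BIN_MS = 10
SIGMA_MS = 30
trials = [[0, 1, 0, 2, 0, 0], [0, 1, 2, 0, 0, 0]]  # spike counts, one channel
smoothed = [gaussian_smooth(trial, SIGMA_MS / BIN_MS) for trial in trials]
# Convert counts per bin to spikes/sec.
rate_hz = [x * 1000 / BIN_MS for x in psth(smoothed)]
print(rate_hz)
```

For model outputs, `psth` would be applied directly to the inferred rates, without the smoothing step.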

#### Decoding Performance


