# Saccade-Transients
## 2023-11-24
Q: Are we beating the peri-saccadic average?
Like the EV metric, but for saccades.
## Update 2023-01-05
Here is a quick update on what I've been working on.
Outline:
1. ModelWrapper format for training/shifting/adapting/etc.
2. Pytorch Native version
3. Model interpretability
4. Next steps
### 1. ModelWrapper
In the existing versions of the code, I'd been using an inheritance-based style of building models. There was an `Encoder` class that contained a `loss` property and `training_step` and `validation_step` methods. Any model we wanted to work with had to inherit from that class.
I decided to try implementing a **Wrapper**-style training module. Instead of building all models to inherit from `Encoder`, we just use `nn.Module` and pass that model into a `Wrapper` that knows how to compute training/validation losses and do other things (e.g., shifting the stimulus, or modulating the model with other covariates).
Take shifting as an example: the **Shifter Model** that I used to learn to correct the stimuli inherited the Encoder class.
That architecture had a convolutional **core** and a **Point Readout**, where each (real) neuron had 2 parameters specifying its spatial position in the convolutional output. Those positions could be shifted around by the shifter network. This is depicted in the figure from my paper.

For some reason I thought it was computationally more straightforward to apply the shifter on the Readout during training, when what we really wanted was a way to shift the stimulus. So I would shift the readout during training, and after training I'd correct the stimulus by shifting it, and then I could train any model.
An alternative approach would be to just shift the stimulus.
The benefits of this are:
1. ANY model can be a shifter model. It just needs a shifter plopped in front of it.
2. The shifter can perform affine transformations of the stimulus (scaling and rotation included, not just translation).
3. This gets us a lot closer to the "adapter" idea.
Because I've been working on improving the data preprocessing pipeline, I wanted to test out improving the learning of the shifters. To this end, I wrote an affine shifter wrapper and the model wrapper class it inherits from.
In the `shifters` branch of foundation and the `shifters` branch of models, you can find the new [ModelWrapper](https://github.com/VisNeuroLab/models/blob/shifters/base.py) class and the [Shifter](https://github.com/VisNeuroLab/models/blob/shifters/shifters.py) class.
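To make the idea concrete, here's a minimal sketch of the wrapper pattern (hypothetical names and simplified shapes; the real classes in the links above are the reference): any `nn.Module` model gets wrapped, and a small shifter network maps eye position to an affine transform applied to the stimulus before the model sees it.
```
import torch
import torch.nn as nn
import torch.nn.functional as F

class AffineShifterWrapper(nn.Module):
    """Sketch: wrap ANY model and learn an affine correction of the stimulus."""
    def __init__(self, model: nn.Module, eyepos_dim: int = 2, hidden: int = 20):
        super().__init__()
        self.model = model
        # small MLP maps eye position -> 6 affine parameters (a 2x3 matrix)
        self.shifter = nn.Sequential(
            nn.Linear(eyepos_dim, hidden), nn.Softplus(), nn.Linear(hidden, 6))
        # initialize to the identity transform so training starts unshifted
        self.shifter[-1].weight.data.zero_()
        self.shifter[-1].bias.data.copy_(torch.tensor([1., 0., 0., 0., 1., 0.]))

    def forward(self, stim: torch.Tensor, eyepos: torch.Tensor):
        # stim: [batch, channels, H, W]; eyepos: [batch, eyepos_dim]
        theta = self.shifter(eyepos).view(-1, 2, 3)
        grid = F.affine_grid(theta, stim.shape, align_corners=False)
        shifted = F.grid_sample(stim, grid, align_corners=False)
        return self.model(shifted)  # any downstream model works
```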
**I think we should probably walk through the code together and think about how we might better organize it, but I quite like this formulation.**
### 2. Native
I tried to get a working pytorch native model and I had success in my `saccade-transients` repo, but when I copied essentially the same code over to `foundation`, it didn't work. I don't fully understand what's up, but things seem more brittle in foundation.
Ultimately, I realized that there are A LOT of small things we implemented in NDNLayer that are really nice and that we would have to reinvent if we went native.
I think the main problem with NDNLayer right now is that we put the nonlinearity and normalization inside it. But those are flags that default to false, so I'm leaning towards using the NDNLayer base and building blocks around it, like one would with pytorch native modules. That seems like a more pythonic style of building: NDNLayer should not handle everything inside it. It should just do the linear part of the operation (weighted sum or convolution).
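As a sketch of what I mean by building blocks around the linear layer (using plain `nn.Conv2d` as a stand-in for the linear part that NDNLayer would do):
```
import torch.nn as nn

# The layer does only the linear operation; nonlinearity and normalization
# are separate modules stacked around it, pytorch-native style.
def conv_block(in_ch, out_ch, ksize):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, ksize, padding='same'),  # linear part only
        nn.ReLU(),                                        # nonlinearity outside
        nn.BatchNorm2d(out_ch),                           # normalization outside
    )
```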
### 3. Model interpretability
So, I can regularly get CNNs that mostly fit the transients. There has to be some need for extra-retinal stuff (and we'd like to quantify that), but we also want to be able to use these models to show HOW the neurons are gaining this selectivity.
How do we know what the models are doing? This is obviously a much bigger question than we can answer and there are [huge efforts](https://www.anthropic.com/) devoted to this general question. For simple feedforward CNNs this is more straightforward than for LLMs, transformers, or generative models. Basically, the tools we already have available fall into two fundamental categories: **feature visualization** and **wiring diagrams**.
**feature visualization** means looking at the features that the neurons in the network are selective for. There are a number of ways to do this: MEIs, white-noise analysis, and visualizing the weights directly.
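For example, here's a bare-bones sketch of the MEI idea: gradient ascent on the stimulus to maximize one unit's response (`model`, the shapes, and the clamping range are placeholders, not our actual code):
```
import torch

def mei(model, unit, shape=(1, 24, 35, 35), steps=200, lr=0.1):
    # start from noise and gradient-ascend the stimulus
    x = torch.randn(shape, requires_grad=True)
    opt = torch.optim.Adam([x], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = -model(x)[0, unit]   # maximize this unit's output
        loss.backward()
        opt.step()
        with torch.no_grad():
            x.clamp_(-1, 1)         # keep the stimulus bounded
    return x.detach()
```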
**wiring diagrams** means finding small subnetworks within the bigger CNN that explain a specific selectivity.
A related reference here is the [hamming window paper](https://openaccess.thecvf.com/content/ICCV2021/papers/Tomen_Spectral_Leakage_and_Rethinking_the_Kernel_Size_in_CNNs_ICCV_2021_paper.pdf) on spectral leakage and kernel sizes in CNNs.
### Next steps
Scientific ideas:
* Parametric front end
* recurrence using ResNet blocks
* [Soft weight-sharing ResNet](https://arxiv.org/pdf/1902.09701.pdf)
#### Parametric "front end"
Main idea is that if the transients really are driven by
## Debugging notes 2022-12-01
After a good amount of debugging, I've had success fitting transients again! I've modified 3 parts of code that I'll highlight here and then explain what I found below:
1. Datafilters
2. Training
3. Architectures
### Datafilters
The problem we had before was that too much data was being thrown out. I finally understood what was going wrong.
The original datafilter code fits a running median of the firing rate for each cell, flags deviations, and then eliminates the cell from the dataset if too much of the data is thrown out.
Before, we were seeing a big jump in the data happening about halfway through and almost half the cells were being thrown out.
It turns out that this jump was actually the result of the different stimulus conditions.
Note that when we call `datasets.Pixel`, we request which conditions are loaded.
```
requested_stims=['Gabor', 'BackImage']
```
The problem was that datafilters was correctly identifying that something big had changed in the neural responses, but what had changed was that the experiment changed. The stimulus had changed from `Gabor` to `BackImage` and as a result, we were throwing out the neurons that were most sensitive to this change.
The way I *fixed* it is a kludge at best. I take advantage of the fact that natural images are shown interleaved between Gabor trials: if I compute datafilters using the real time indices of the data (instead of the requested stimulus-based indices), then the two conditions will be interleaved and these jumps won't be so obvious.
The key lines in `compute_datafilters` are:
```
fr, _, ft, ftend = self.get_firing_rate_batch(batch_size=batch_size)
fr = fr * 240
ind = np.argsort(ft)    # sort samples by real experiment time
frsort = fr[ind, :]
df = np.zeros(fr.shape)
for cc in range(fr.shape[1]):
    df[:, cc] = firingrate_datafilter(frsort[:, cc],
                                      Lmedian=Lmedian, Lhole=Lhole,
                                      FRcut=FRcut, frac_reject=frac_reject,
                                      to_plot=verbose, verbose=verbose)
jind = np.argsort(ind)  # inverse permutation
df = df[jind, :]        # unsort back to the requested-stimulus order
```
I compute the datafilters per cell on a sorted version of the firing rates and then unsort at the end.
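The sort/unsort trick works because the argsort of an argsort is the inverse permutation:
```
import numpy as np

x = np.array([3.0, 1.0, 2.0])
ind = np.argsort(x)                   # sort order
jind = np.argsort(ind)                # inverse permutation
assert np.allclose(x[ind][jind], x)   # sorted-then-unsorted == original
```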
Now, this doesn't solve the problem generally and more should be done to make sure that we handle the different stimuli differently. I'm flagging this here, so we can decide what to do about that, but let's not let that get in the way of moving forward now...
### Training
To figure out why things were so off, I started with the question: why can't I perfectly fit the mean of the transients?
To do that, I created a model with literally 2 covariates: a fixation-onset covariate and a slow-drift covariate.
Here's the drift term for each unit. This captures how the firing rate varies over the course of the session. Anything hovering around zero means the firing rate didn't change much from baseline; linear trends mean the cell increased or decreased its firing rate over time, and jumps mean something changed abruptly.

And here is the fixation onset covariate for each cell:

Unsurprisingly, this model can perfectly fit the mean of the transients:

And it actually does a pretty good job of explaining the validation set overall. All just by guessing the mean response every time a saccade happens.

The nice thing about this model is it is small and can be fit very rapidly with LBFGS.
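Here's a sketch of what fitting that small model with LBFGS looks like (the design matrix `X` of drift + fixation-onset covariates and the Poisson loss are my assumptions about the setup, not the actual code):
```
import torch
import torch.nn.functional as F

def fit_baseline(X, robs, max_iter=100):
    # X: [T, n_covariates] drift + fixation-onset covariates
    # robs: [T, n_cells] observed spike counts
    W = torch.zeros(X.shape[1], robs.shape[1], requires_grad=True)
    b = torch.zeros(robs.shape[1], requires_grad=True)
    opt = torch.optim.LBFGS([W, b], max_iter=max_iter)
    def closure():
        opt.zero_grad()
        rate = F.softplus(X @ W + b)
        loss = (rate - robs * torch.log(rate + 1e-8)).mean()  # Poisson NLL
        loss.backward()
        return loss
    opt.step(closure)
    return W.detach(), b.detach()
```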
So, here's the idea: always fit models after initializing with the "drift+saccades" baseline model. GLMs, GQMs, CNNs... they should all start with the initial model.
There's a bigger question of why the Adam optimizer doesn't converge to this solution when fitting these parameters as part of a CNN. In the short term, I don't need to know the answer to that. Right now, we want 1) a model that works and 2) ways to interpret it. We can worry about how/why fitting works once we have 1 and 2 working (and therefore have some results we can talk about/write up).
#### Notes on training and test set contamination:
I double-checked that the training set is generated by randomly sampling fixations and that no test data comes from the same fixations that were included in the training set. The main idea here is to avoid cross-validation bleed-through, which happens when test time points are sampled from the same individual fixations that contributed to the training set.
I caught a bug in the way I was handling the `maxsamples` variable. `train_inds, val_inds = ds.get_train_indices()` was correctly sampling entire fixations, but in the subsequent code where I threw out samples based on what would fit on the GPU, I was injecting problematic indices that mixed the training and test fixations. I fixed this by adding a `max_sample` flag to `get_train_indices`:
```
train_inds, val_inds = ds.get_train_indices(
    max_sample=int(0.85*maxsamples))
```
#### Notes on training modifiers:
After initializing with the baseline, I fit models that include *modifiers*.
*Modifiers* are offsets/gains operating on the readout stage, driven by covariates like time in the experiment and fixation onset. These should be able to perfectly fit the transients after fixation onset because there is literally a term that can capture the mean effect (or a multiplicative gain effect).
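As a sketch (a hypothetical minimal version, not the actual implementation), a modifier is just a learned per-neuron offset and gain driven by those covariates:
```
import torch
import torch.nn as nn

class OffsetGainModifier(nn.Module):
    def __init__(self, n_covariates, n_neurons):
        super().__init__()
        self.offset = nn.Linear(n_covariates, n_neurons, bias=False)
        self.gain = nn.Linear(n_covariates, n_neurons, bias=False)

    def forward(self, rate, cov):
        # rate: [batch, n_neurons] readout output
        # cov: [batch, n_covariates], e.g., drift + fixation-onset basis
        return rate * (1 + self.gain(cov)) + self.offset(cov)
```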
### Architectures
The model I originally introduced you to had an architecture like this:

The input has $24$ time lags embedded as the channels of a $35\times35$ dimensional image. There are 3 feedforward convolutional layers. Convolutions are all 2D. Time only exists in the first layer (because it's embedded in the channel dimension). Each convolutional block consists of 3 stages: 2D conv ('same' padding) -> ReLU -> Batch Norm.
In the example above, the layers each have 20 units in them and they have $5\times5$, $5\times5$, and $9\times9$ convolutional kernels, respectively. There are $54$ outputs corresponding to $54$ neurons in the dataset.
We concatenate the output of each convolutional layer into a $60\times35\times35$ dimensional tensor, and then extract $54$ points by bilinear interpolation via the pytorch function `grid_sample`. This gives us a $60$-dimensional vector for each neuron, which we then multiply by a $60$-dimensional feature vector. Finally, we add a bias term and pass through a [SoftPlus](https://paperswithcode.com/method/softplus) nonlinearity.
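In code, that point readout is roughly the following (a sketch with hypothetical names; shapes match the example above):
```
import torch
import torch.nn.functional as F

def point_readout(feats, pos, features, bias):
    # feats: list of three [batch, 20, 35, 35] layer outputs
    # pos: [54, 2] per-neuron (x, y) positions in [-1, 1]
    # features: [54, 60] per-neuron feature weights; bias: [54]
    core = torch.cat(feats, dim=1)                  # [batch, 60, 35, 35]
    grid = pos.view(1, -1, 1, 2).expand(core.shape[0], -1, -1, -1)
    sampled = F.grid_sample(core, grid, mode='bilinear', align_corners=False)
    sampled = sampled.squeeze(-1).permute(0, 2, 1)  # [batch, 54, 60]
    rate = (sampled * features).sum(-1) + bias      # [batch, 54]
    return F.softplus(rate)
```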
I have figures from before showing that this architecture can learn to fit the neurons well and that it can produce saccade transients.
Here are 3 example neurons and the model fits. The CNN is the architecture described above, but with fewer units.

The exciting bit for me was that I could get small neural networks to produce this type of transient, and the first layer tended to split into units that looked like known classes of retinal ganglion cells (RGCs) in the primate retina.

An older and more straightforward architecture is to just have a feedforward CNN that leads to a dense readout. The main issue with that type of architecture is that it produces a huge number of parameters in the readout (in this case it would be $20\times35\times35\times54=1{,}323{,}000$). I discuss this in a [blog post](https://jake.vision/blog/lurz-paper) from a few years ago, and I really thought that the readout layer with a 2D point per neuron was a good idea, but let's at least roll back one iteration and go with a dense readout that is factorized over space and features (as described in [this paper](https://proceedings.neurips.cc/paper/2017/file/8c249675aea6c3cbd91661bbae767ff1-Paper.pdf)).

This reduces the parameters in readout from $1,323,000$ to $20\times54 + 35\times35\times54 = 67,230$ which is a HUGE reduction. It's not as big a reduction as the point readout, but it's more standard. So that's what we're going to try for now as a comparison point.
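A minimal sketch of that factorized readout (my naming; the paper's version adds regularization on top of this):
```
import torch
import torch.nn as nn
import torch.nn.functional as F

class FactorizedReadout(nn.Module):
    def __init__(self, n_channels=20, H=35, W=35, n_neurons=54):
        super().__init__()
        # per neuron: an [H, W] spatial mask and an [n_channels] feature vector
        self.spatial = nn.Parameter(torch.randn(n_neurons, H, W) * 0.01)
        self.feature = nn.Parameter(torch.randn(n_neurons, n_channels) * 0.01)
        self.bias = nn.Parameter(torch.zeros(n_neurons))

    def forward(self, x):
        # x: [batch, n_channels, H, W] -> [batch, n_neurons]
        pooled = torch.einsum('bchw,nhw->bnc', x, self.spatial)
        rate = (pooled * self.feature).sum(-1) + self.bias
        return F.softplus(rate)
```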
## Results
I trained several models of different sizes and with point readouts or dense readouts.
The figure below shows the test likelihood for the different depths/sizes of neural networks, plotted as pairs where the point readout and dense readout are next to each other.

Each point is the likelihood for an individual neuron. One thing that stands out to me is that the semi-transparent points (corresponding to the dense readout) are usually slightly higher than the solid points. And the best model is "big", meaning it had a lot of subunits per layer.
Some of these models produced transients, but I decided to focus on the best models: either the 3- or 4-layer dense model.
So I started with the 3-layer Big Dense model and fit several versions with and without modifiers (saccade onset, drift).
With 5 runs of fitting a CNN with and without modifiers, we can see that not including modifiers does a little worse, but not by much.

And some of these fits create transients! Including the models that have no modifiers. In the plot below, we have the same models as above, but the NoMod models have solid lines and their paired modifier versions have dashed lines.

So what do these look like?
**Layer 1**

These mostly look good. We see small, punctate units, some with Gabor-like spatial filters, and lots of different spatial and temporal scales of selectivity. I think this could be reduced, and I have ideas for how to do that.
**Layer 2**

Interesting spatial structure, but largely uninterpretable. We'll need MEI analysis working to interpret these...
**Layer 3**

Same as for layer 2.
**Spatial Readout**

This looks awesome. It looks like receptive fields. I think this is really promising.
**Feature Readout**

This also looks pretty cool. There's certainly structure. We need better tools for visualizing what is special about features 13–20 and the neurons that are selective for those features (e.g., 20, 22, 24, 28, 30, 32, etc.)
I think we're in business to start unpacking models to understand why/when they generate transients.
## Conclusion
I think we're good to move forward on the following things:
1) fitting more datasets
2) evaluate when models produce good transient responses
3) figure out why models produce transients (using MEI and integrated gradients; see the sketch after this list)
4) Pruning the big models down to (potentially) more interpretable small models
5) Figure out how to include useful modifiers since mine seem kind of borked
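For item 3, here's a sketch of integrated gradients for a single model unit (the standard formulation; `model` and the shapes are placeholders):
```
import torch

def integrated_gradients(model, x, unit, baseline=None, steps=50):
    # average the gradient along a straight path from baseline to x
    if baseline is None:
        baseline = torch.zeros_like(x)
    grads = torch.zeros_like(x)
    for alpha in torch.linspace(0, 1, steps):
        xi = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
        model(xi)[0, unit].backward()
        grads += xi.grad / steps
    return (x - baseline) * grads   # attribution map, same shape as x
```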
To replicate the figures from this analysis use these code bases:
[datasets](https://github.com/jcbyts/datasets): this has the updated Pixel dataset with new datafilters / training indices
[saccade-transients](https://github.com/jcbyts/saccade-transients): this has the models, utilities, and scripts for fitting and plotting
* [This script](https://github.com/jcbyts/saccade-transients/blob/main/scripts/test_CNNdense_saccade_transients.py) should be the first one you check out
[NDNT](https://github.com/NeuroTheoryUMD/NDNT): I'm still using it for handling regularization, but I have an easy way to pull that module out that I can show you later in the week.
Note: I just noticed I've been committing and pushing to github as you because I never changed the user on the desktop computer. I'll fix that now.
## More scientific background
The overarching goal of this line of research is a good model of the visual cortex of primates during natural vision. Surprisingly, after half a century of studying visual cortex (including Nobel prizes awarded to Hubel and Wiesel), we don't have a solid theory of what visual cortex accomplishes, and worse, we don't even have a working model that generates responses mimicking the real responses of neurons under natural conditions. Yes, we have lots of models that do a good job for artificial conditions, but we fall short when it comes to anything like natural conditions (i.e., movies, with eye movements).
One of the most striking features of natural vision is that virtually every animal collects visual information with a "saccade and fixate" pattern of eye movements. Saccades are ballistic eye movements that generate rapid shifts of the retina. Many people have shown that saccades have a large effect on [perception](https://www.sciencedirect.com/science/article/pii/S0166223600016854?casa_token=FvKjMxgJz7kAAAAA:ehcXedNWRMyQa3M_59Cqvyd9qjiXXlRlxTJeUplzQDPG_CASolNvYS35ddd1dqd0fPDhxjhxns09) and [neural responses](https://www.sciencedirect.com/science/article/pii/S0959438811000808?casa_token=p4hUrK5Ny70AAAAA:CFXpFESDe32lhY9kvOaOsORioedHAvnU2vXYoLRAQJIa0oeWFGRLRtK1xngX3vD4gXOIrJ3_Z1Wr).
More recently, [we and others](https://www.biorxiv.org/content/10.1101/2022.08.23.504847v1.abstract) showed that saccades on natural images generate large responses that relate to gaze shifts and not just to the eyes moving. This contrasts with the [Miura and Scanziani Nature paper](https://www.nature.com/articles/s41586-022-05196-w), which shows a large extra-retinal component.
Now, we know there IS an extra-retinal component but we don't know what information it contains and, psychophysically, it is best described as suppressive. So, what information is in the transient? How visual is it?
In our paper, we showed that in mouse the saccade modulation has latencies that depend on the stimulus and the tuning of the neurons.


In the dark, most cells are only suppressed. With visual stimuli, they have big transients, and the timing depends on the tuning of the cells. This is consistent across mouse and marmoset, suggesting it's a general principle.
Here's the thing: the linear receptive field model does not produce transients of this scale.
Biphasic linear filters will produce transients, but that's not a good model of these neurons: the biphasic stage is in the retina. And importantly, there are 2 primary cell classes there: a sustained temporal response with tiny RFs (parvo) and a fast transient response with larger RFs (magno).
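To see the distinction, a toy example: convolve a step (fixation-onset-like) input with a biphasic vs. a sustained temporal filter (made-up filter shapes, just for illustration):
```
import numpy as np

t = np.arange(0, 0.2, 1/240)                        # 240 Hz time axis
biphasic = np.sin(2*np.pi*10*t) * np.exp(-t/0.03)   # fast push-pull ("magno"-like)
sustained = np.exp(-t/0.1)                          # slow low-pass ("parvo"-like)
step = np.ones(100)                                 # stimulus appears and stays on
transient = np.convolve(step, biphasic)[:100]       # peaks, then decays back down
tonic = np.convolve(step, sustained)[:100]          # stays elevated
```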
So, I think Magno and Parvo really should fall out of a good description of the visual aspect of transients. Some mixture of magno and parvo pathways should be what's generating these responses.
And this is textbook. Here's a figure from Kandel and Schwartz (THE neuroscience textbook). This is a simplified schematic of the visual system. We're modeling V1, and we will soon be modeling the LGN, V2, V4, and MT.
This is the hierarchy that inspired Yann LeCun (among others) to build CNNs, and it should look pretty similar to diagrams of feedforward neural networks.


Building in parallel pathways and recurrent computations will be really cool extensions of these models, but the MEI method and integrated gradients will be more complicated there. The first-order goal now is to focus on results using simple CNNs and start building better tools for interpreting these models.
Anyway, all this is to say: we *should* see "magno-like" and "parvo-like" pathways in our networks and that should be part of what's driving the transients (I think). I'd be interested if there are very different ways to explain these responses, but I think this is the likely outcome (and a really easy result to write up).