# Nobrainer API discussion
Document to discuss the nobrainer API
Several issues have raised the need for a high-level API for nobrainer, in addition to the lower-level interfaces.
_Primary audience_: users of nobrainer models; think scripters more than developers. Other users will have command-line interfaces (via nobrainer zoo) that they can use from a shell. Those interfaces should also mimic as much of this API as possible.
## Proposal:
### Principles
- Keep API surface as close to scikit-learn and nilearn as possible to maximize community reuse
- https://www.analyticsvidhya.com/blog/2021/04/sklearn-objects-fit-vs-transform-vs-fit_transform-vs-predict-in-scikit-learn/
- https://nilearn.github.io/stable/modules/reference.html
- Limit internal use of for loops to essential and useful scenarios. For example, fit, predict, transform
Any model should support a scikit-learn-like or nilearn-like API:
**Estimator API:** Everything is an estimator and other classes for transformation or prediction are derived from it.
- `fit`: estimate or train a model
- `fit_transform`: train and apply estimates to input
- `load`: load a saved model
- `predict`: apply trained model to input
- `save`: save the model
- `transform`: transform inputs using estimates
- `generate`: for generators
- `evaluate`: evaluate using a trained model, test inputs, and labels. Likely a function, not a method.
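To make the method list concrete, here is a minimal sketch of how such a base class could look. It is an illustration only: `BaseEstimator`, `MeanCenter`, `Threshold`, `accuracy`, and the signature of `evaluate` are hypothetical names chosen for this example, not existing nobrainer code, and plain Python lists stand in for image tensors.

```python
from abc import ABC, abstractmethod


class BaseEstimator(ABC):
    """Hypothetical base class; every model would derive from it."""

    @abstractmethod
    def fit(self, X, y=None):
        """Estimate parameters from data and return self."""

    def fit_transform(self, X, y=None):
        # Default behaviour: fit, then apply the estimates to the same input.
        return self.fit(X, y).transform(X)


class MeanCenter(BaseEstimator):
    """Toy transformer: subtracts the mean learned during fit."""

    def fit(self, X, y=None):
        # Trailing underscore marks an attribute that exists only after fit.
        self.mean_ = sum(X) / len(X)
        return self

    def transform(self, X):
        return [x - self.mean_ for x in X]


class Threshold(BaseEstimator):
    """Toy predictor: labels values above the training mean as 1."""

    def fit(self, X, y=None):
        self.threshold_ = sum(X) / len(X)
        return self

    def predict(self, X):
        return [int(x > self.threshold_) for x in X]


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def evaluate(model, X, y, metric):
    """Free function (not a method): score a fitted model on test data."""
    return metric(y, model.predict(X))


centered = MeanCenter().fit_transform([1.0, 2.0, 3.0])
score = evaluate(Threshold().fit([1.0, 2.0, 3.0, 4.0]), [1.0, 4.0], [0, 1], accuracy)
```

Keeping `evaluate` as a free function means it works with any object that exposes `predict`, which matches the "likely a function, not a method" note.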
Other existing components of nobrainer:
- `losses`: loss functions
- `metrics`: metrics for evaluation
- `models`: the core lower-level 3D models (meshnet, unet, vnet, pgan, ae, vox2vox, cyclegan, simsiam, etc.)
- `transformers`: various data augmentation methods
- `utilities`: I/O, etc.
In addition, we should have models by neuroimaging processing type:
- `connectivity`: (e.g., estimation)
- `detection`: (e.g., anomaly, defaced)
- `embedding`: (e.g., transform image to latent spaces)
- `generation`: (e.g., generate images)
- `preprocessing`: (e.g., augmentations, quality control)
- `quality`: (e.g., quality estimates, control)
- `registration`: (e.g., image, point to point, point to image)
- `segmentation`: (e.g., brain extraction, any labeling of structures)
- `transformation`: (e.g., mapping diffusion to T2)
Unlike scikit-learn, we will support tensors with channels.
Potentially also a `Pipeline` class, in case chaining transformers is easier that way.
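A `Pipeline` in the scikit-learn sense could be sketched roughly as follows. This is an assumption about what the proposal means, not an agreed design: the `Pipeline` class here, along with the toy steps `Rescale` and `AboveHalf`, is hypothetical and uses plain lists in place of image tensors.

```python
class Pipeline:
    """Hypothetical: chains transformers, ending in a final estimator."""

    def __init__(self, steps):
        self.steps = list(steps)  # list of (name, estimator) pairs

    def fit(self, X, y=None):
        # Every step but the last is fit and applied; the last is only fit.
        for _, step in self.steps[:-1]:
            X = step.fit_transform(X, y)
        self.steps[-1][1].fit(X, y)
        return self

    def predict(self, X):
        # Re-apply the fitted transformers, then predict with the final step.
        for _, step in self.steps[:-1]:
            X = step.transform(X)
        return self.steps[-1][1].predict(X)


class Rescale:
    """Toy transformer: divides by the maximum seen during fit."""

    def fit(self, X, y=None):
        self.max_ = max(X)
        return self

    def transform(self, X):
        return [x / self.max_ for x in X]

    def fit_transform(self, X, y=None):
        return self.fit(X, y).transform(X)


class AboveHalf:
    """Toy final step: predicts 1 for values above 0.5."""

    def fit(self, X, y=None):
        return self

    def predict(self, X):
        return [int(x > 0.5) for x in X]


pipe = Pipeline([("rescale", Rescale()), ("classify", AboveHalf())])
pipe.fit([2.0, 4.0, 8.0])
labels = pipe.predict([1.0, 6.0])
```

Because the pipeline itself exposes `fit` and `predict`, it can be used anywhere a single estimator is expected.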
### Data input:
- TFRecords with appropriate metadata about structures
- Niimg-like objects: see https://nilearn.github.io/manipulating_images/input_output.html
## Scripting examples
1. **Brain extraction**
Training:
```python
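from nobrainer.segmentation import Segmentation
from nobrainer.models import unet

# Training: wrap a base architecture in the task-level estimator.
# `dataset` is assumed to have been prepared beforehand (see "Data input").
model = Segmentation(
    basemodel=unet,
    batchnorm=True,
    block_shape=...,
)
model.fit(dataset, checkpoint_dir="/path/to/checkpoints", batch_size=...)
model.save()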
```
Internally, `model.fit` would call something like this:
```python
import nobrainer

# `block_shape` is the value passed to `Segmentation` at construction.
basemodel = nobrainer.models.unet(
    n_classes=1,
    input_shape=(*block_shape, 1),
    batchnorm=True,
)
```
Inference:
```python
import nobrainer

# This should return a Segmentation model, so the saved model
# must record its class information.
model = nobrainer.load_model("/path/to/saved/model")
niimg_like_object = model.predict("niimg like object")
```
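One way a saved model could carry its class information is to store the class name next to the constructor arguments and look it up in a registry on load. The sketch below is only an illustration under that assumption: `_REGISTRY`, `register`, the JSON layout, and this stand-in `Segmentation` class are all hypothetical, and weights are left out entirely.

```python
import json
import os
import tempfile

_REGISTRY = {}  # hypothetical: maps class names to estimator classes


def register(cls):
    """Class decorator that makes an estimator discoverable by name."""
    _REGISTRY[cls.__name__] = cls
    return cls


@register
class Segmentation:
    """Stand-in for the proposed estimator; holds constructor arguments only."""

    def __init__(self, block_shape=None, batchnorm=False):
        self.block_shape = block_shape
        self.batchnorm = batchnorm

    def save(self, path):
        # Store the class name alongside the constructor arguments.
        meta = {
            "class": type(self).__name__,
            "params": {"block_shape": self.block_shape, "batchnorm": self.batchnorm},
        }
        with open(path, "w") as f:
            json.dump(meta, f)


def load_model(path):
    """Rebuild an instance of whichever registered class was saved."""
    with open(path) as f:
        meta = json.load(f)
    cls = _REGISTRY[meta["class"]]
    return cls(**meta["params"])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.json")
    Segmentation(block_shape=[32, 32, 32], batchnorm=True).save(path)
    restored = load_model(path)
```

A real implementation would also have to persist the trained weights; the point here is only that `load_model` can return the right estimator class without the caller naming it.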
2. **Transfer learning**
```python
import nobrainer
from nobrainer.segmentation import Segmentation

model = nobrainer.load_model("/path/to/saved/model")
assert isinstance(model, Segmentation)
# Continue training the loaded model on a new dataset.
model.fit(
    new_dataset,
    checkpoint_dir="/path/to/checkpoints",
    batch_size=8,
    multi_gpu=True,
)
model.save()
niimg_like_object = model.predict("niimg like object")
```
3. **Generation**
```python
from nobrainer.generation import ProgressiveGAN
model = ProgressiveGAN()
model.fit(dataset, ...)
model.save()
niimg_like_object, latent = model.generate(resolution=256)
```
4. **Registration**
```python
import nobrainer
from nobrainer.registration import SynthMorph

model = nobrainer.load_model("/path/to/saved/model")
assert isinstance(model, SynthMorph)
model.fit(source="niimg like object", target="niimg like object")
model.transform_.to_x5()  # save estimated transform to a standard file
```
Or fit and transform in one call:
```python
niimg_like_object = model.fit_transform(
    source="niimg like object",
    target="niimg like object",
)
```
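The `fit` / `transform_` / `fit_transform` relationship in the registration example can be illustrated with a toy model. This is a sketch under simplifying assumptions: `TranslationRegistration` is a hypothetical class, it aligns 1-D point lists by their mean offset, and its `transform_` is a plain float rather than a transform object with a `to_x5()` method.

```python
class TranslationRegistration:
    """Toy stand-in for a registration model on 1-D point sets."""

    def fit(self, source, target):
        # Estimated transform: the shift that maps the source mean
        # onto the target mean. Stored with a trailing underscore
        # because it only exists after fitting.
        self.transform_ = sum(target) / len(target) - sum(source) / len(source)
        return self

    def transform(self, source):
        # Apply the estimated shift to every source point.
        return [p + self.transform_ for p in source]

    def fit_transform(self, source, target):
        # Estimate the transform and return the moved source in one call.
        return self.fit(source, target).transform(source)


model = TranslationRegistration()
moved = model.fit_transform(source=[0.0, 1.0, 2.0], target=[5.0, 6.0, 7.0])
```

Note that `fit` takes `source` and `target` here instead of scikit-learn's `X` and `y`, which is one place where the proposed API would deviate from the scikit-learn convention.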
## Implementation details and class abstractions