## Prerequisites
Basic Python concepts:
- magic method
- zip
- iterable, iterator, generator
## Introduction
When training deep learning (DL) models, we usually need to preprocess the data and find a way to feed it into the model in batches. PyTorch (`torch`) provides two utility classes, `torch.utils.data.Dataset` and `torch.utils.data.DataLoader`, to make this easier: `Dataset` defines how to access individual samples, and `DataLoader` groups them into batches.
## Examples
A typical workflow looks like this:
0. Import the modules
```python=
import torch
from torch.utils.data import Dataset, DataLoader
```
1. Define a custom dataset class and instantiate it
```python=
class MyDataSet(Dataset):
    def __init__(self, x, y) -> None:
        super().__init__()
        self.x = x
        self.y = y

    def __getitem__(self, index):
        # return one (input, target) sample
        return self.x[index], self.y[index]

    def __len__(self) -> int:
        # number of samples in the dataset
        return len(self.x)

# number of samples, sequence length, feature dimension
N, L, C = 3, 4, 5
x = torch.randn((N, L, C))
y = torch.randn(N)
ds = MyDataSet(x, y)
```
2. Instantiate a DataLoader object
```python=
dl = DataLoader(ds, batch_size=2)
```
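3. Feed the data into the model
```python=
model = torch.nn.Linear(C, 1)  # placeholder model, for illustration only
for x_batch, y_batch in dl:
    output = model(x_batch)  # (batch, L, C) -> (batch, L, 1)
    ...  # compute the loss, backpropagate, update the weights, etc.
```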
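Because the toy dataset has 3 samples and `batch_size=2`, the loop yields two batches and the last one is smaller. A minimal, self-contained sketch that recreates the dataset and records each batch's shapes (the `print` calls are only for illustration):

```python=
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataSet(Dataset):
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

N, L, C = 3, 4, 5
ds = MyDataSet(torch.randn(N, L, C), torch.randn(N))
dl = DataLoader(ds, batch_size=2)

shapes = []
for x_batch, y_batch in dl:
    shapes.append((tuple(x_batch.shape), tuple(y_batch.shape)))
    print(x_batch.shape, y_batch.shape)
# torch.Size([2, 4, 5]) torch.Size([2])
# torch.Size([1, 4, 5]) torch.Size([1])
```

If the incomplete last batch is undesirable, `DataLoader` also accepts `drop_last=True`, which discards it.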
## Custom Dataset
We usually build a custom dataset by subclassing `torch.utils.data.Dataset`. For such a (map-style) dataset, we only need to implement three magic methods in our class:
- `__init__()`: stores the raw data
- `__getitem__()`: returns the sample at a given index, so `ds[i]` works
- `__len__()`: returns the number of samples, so `len(ds)` works
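These methods are what make the built-in syntax work: `ds[i]` calls `__getitem__(i)` and `len(ds)` calls `__len__()`. A small sketch with a hypothetical `SquaresDataset` (not part of PyTorch) that stores plain Python data and converts it to tensors only when a sample is accessed:

```python=
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (i, i**2) as float tensors."""

    def __init__(self, n):
        super().__init__()
        # store the raw data
        self.xs = list(range(n))

    def __getitem__(self, index):
        # called by ds[index]
        x = self.xs[index]
        return torch.tensor(float(x)), torch.tensor(float(x * x))

    def __len__(self):
        # called by len(ds)
        return len(self.xs)

ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (tensor(3.), tensor(9.))
```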
## DataLoader
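`DataLoader` wraps a `torch.utils.data.Dataset` passed to its constructor. It is an iterable rather than a generator: each call to `iter()` on it (for example, each new `for` loop) returns a fresh iterator over the batches, so the same `DataLoader` can be looped over once per epoch.
Some important constructor parameters:
- `batch_size`: number of samples per batch (default `1`)
- `shuffle`: whether to reshuffle the data at every epoch (default `False`)
- `pin_memory`: if `True`, copies the batch tensors into page-locked (pinned) host memory, which speeds up transfers to the GPU; usually set to `True` when training on a GPU
- `collate_fn`: the function that merges a list of samples into a batch (see below)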
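To see that a `DataLoader` can be iterated over repeatedly, and how `batch_size` and `shuffle` behave, here is a sketch over a toy `TensorDataset` holding the integers 0 to 9. The seeded `torch.Generator` is an assumption added only to make the shuffled order reproducible; the actual order depends on the seed and PyTorch version, so it is not shown.

```python=
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))  # sample i is the 1-tuple (tensor(i),)

# no shuffling: consecutive batches of 4, the last one shorter
dl = DataLoader(ds, batch_size=4)
first_pass = [batch[0].tolist() for batch in dl]
second_pass = [batch[0].tolist() for batch in dl]  # a new loop starts from the beginning
print(first_pass)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# shuffle=True: same batch sizes, but a new random order on every pass
g = torch.Generator().manual_seed(0)  # seeded only for reproducibility
dl_shuffled = DataLoader(ds, batch_size=4, shuffle=True, generator=g)
epoch1 = [v for batch in dl_shuffled for v in batch[0].tolist()]
epoch2 = [v for batch in dl_shuffled for v in batch[0].tolist()]

# iter() hands out an iterator; next() pulls one batch from it
it = iter(dl)
print(next(it)[0].tolist())  # [0, 1, 2, 3]
```

Both shuffled epochs contain every value exactly once; only the order changes between them.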
### How DataLoader prepares a batch
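When iterating, `DataLoader` first fetches the samples for the current batch from the `Dataset` with `__getitem__()`, i.e. `[ds[i] for i in indices]`. Since our `__getitem__()` returns a tuple `(x, y)`, this gives **a list of tuples** (one tuple per data point). The list is then passed to `collate_fn()`, which merges it into the final batch; the default `collate_fn()` is also where NumPy arrays and Python numbers get converted to `torch.Tensor`.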
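The fetch-then-collate steps can be sketched in plain Python. This is a simplified, single-process approximation (no worker processes, no shuffling, no memory pinning), not PyTorch's actual internal code; it uses the real `SequentialSampler`, `BatchSampler`, and `default_collate` from `torch.utils.data` (the latter exposed publicly in recent versions), while the helper name `iterate_batches` is hypothetical.

```python=
import torch
from torch.utils.data import (BatchSampler, DataLoader, SequentialSampler,
                              TensorDataset, default_collate)

def iterate_batches(dataset, batch_size, collate_fn=default_collate):
    """Hypothetical helper: roughly what DataLoader does without workers or shuffling."""
    sampler = SequentialSampler(dataset)  # yields indices 0, 1, 2, ...
    for indices in BatchSampler(sampler, batch_size, drop_last=False):
        samples = [dataset[i] for i in indices]  # list of tuples, via __getitem__
        yield collate_fn(samples)  # merge the list into one batch

ds = TensorDataset(torch.randn(3, 4, 5), torch.randn(3))
ours = list(iterate_batches(ds, batch_size=2))
reference = list(DataLoader(ds, batch_size=2))  # the real thing, for comparison
```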
Let's first look at the raw list by passing an identity function as `collate_fn`, so that nothing gets merged:
```python=
dl_no_collate = DataLoader(ds, batch_size=3, collate_fn=lambda x: x)
raw_batch = next(iter(dl_no_collate))
raw_batch
```
The output is of the form `[(x1, y1), (x2, y2), (x3, y3)]`, a list of tuples.
```python=
[(tensor([[ 1.0475, 0.0485, -0.2177, 1.4522, -0.5661],
[ 0.3699, 1.0205, 0.8040, 0.2916, 0.0446],
[-0.4648, 3.2485, 0.2894, 0.4325, -0.5360],
[-0.7087, 1.1254, 0.0747, -1.0728, -0.7503]]),
tensor(0.9723)),
(tensor([[ 0.0565, -0.8368, 0.9353, -0.9156, -0.4951],
[ 0.3750, 0.5120, -0.9690, 0.4152, -1.1938],
[ 0.9016, 0.0136, -1.0142, 1.8649, -0.1401],
[-0.5947, 0.0313, -0.6279, 1.5463, -0.6546]]),
tensor(0.8549)),
(tensor([[-0.9101, -0.2072, -0.0353, -0.5205, -1.2083],
[-1.6441, -1.0244, 0.1473, -0.3761, -1.0025],
[-1.2569, -0.4252, 0.1050, 1.4469, 0.4834],
[ 0.2454, 0.4732, -0.2637, 0.8446, -1.0841]]),
tensor(1.0406))]
```
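The default `collate_fn()` groups the same field across all samples (all the `x`s together, all the `y`s together) and stacks each group with `torch.stack()`.
For samples that are tuples of tensors, as in our dataset, the following function gives the same result as the default `collate_fn()`. The built-in version additionally handles NumPy arrays, Python numbers, dicts, and nested structures.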
```python=
def default_collate_fn(data):
    '''
    data:
        [(x1, y1),
         (x2, y2),
         (x3, y3)]
    attributes:
        [(x1, x2, x3),
         (y1, y2, y3)]
    '''
    attributes = list(zip(*data))
    ret = []
    for att in attributes:
        ret.append(torch.stack(att))
    return ret
```
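The one non-obvious line is `zip(*data)`: unpacking the list passes each sample to `zip` as a separate argument, and `zip` then pairs up the i-th fields, i.e. it transposes the list of tuples. A plain-Python illustration, with strings standing in for tensors:

```python=
data = [("x1", "y1"), ("x2", "y2"), ("x3", "y3")]

# zip(*data) is the same as zip(("x1", "y1"), ("x2", "y2"), ("x3", "y3"))
attributes = list(zip(*data))
print(attributes)  # [('x1', 'x2', 'x3'), ('y1', 'y2', 'y3')]
```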
```python=
dl = DataLoader(ds, batch_size=3, collate_fn=default_collate_fn)
batch = next(iter(dl))
batch
```
The resulting batch is a list `[x, y]`, where `x` has shape `(3, 4, 5)` and `y` has shape `(3,)`:
```python=
[tensor([[[ 1.0475, 0.0485, -0.2177, 1.4522, -0.5661],
[ 0.3699, 1.0205, 0.8040, 0.2916, 0.0446],
[-0.4648, 3.2485, 0.2894, 0.4325, -0.5360],
[-0.7087, 1.1254, 0.0747, -1.0728, -0.7503]],
[[ 0.0565, -0.8368, 0.9353, -0.9156, -0.4951],
[ 0.3750, 0.5120, -0.9690, 0.4152, -1.1938],
[ 0.9016, 0.0136, -1.0142, 1.8649, -0.1401],
[-0.5947, 0.0313, -0.6279, 1.5463, -0.6546]],
[[-0.9101, -0.2072, -0.0353, -0.5205, -1.2083],
[-1.6441, -1.0244, 0.1473, -0.3761, -1.0025],
[-1.2569, -0.4252, 0.1050, 1.4469, 0.4834],
[ 0.2454, 0.4732, -0.2637, 0.8446, -1.0841]]]),
tensor([0.9723, 0.8549, 1.0406])]
```
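To check the equivalence on this kind of data, one can compare the hand-written function against PyTorch's own default, exposed as `torch.utils.data.default_collate` in recent versions. A minimal sketch, assuming samples are `(tensor, scalar tensor)` tuples as in `MyDataSet`; the function body is a compact but equivalent form of `default_collate_fn` above:

```python=
import torch
from torch.utils.data import default_collate

def default_collate_fn(data):
    # transpose the list of (x, y) tuples, then stack each field
    return [torch.stack(att) for att in zip(*data)]

samples = [(torch.randn(4, 5), torch.randn(())) for _ in range(3)]

ours = default_collate_fn(samples)
builtin = default_collate(samples)

same = len(ours) == len(builtin) and all(torch.equal(a, b) for a, b in zip(ours, builtin))
print(same)  # True
```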
### Practical usage
For practical usage of `DataLoader`, please see [here](https://hackmd.io/@-CDCNK_qTUicXsissQsHMA/SJ6Gjpxv8).
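One common reason to write a custom `collate_fn` is that samples have different lengths, so the `torch.stack()` inside the default would fail. A sketch, assuming a hypothetical dataset of variable-length 1-D sequences with integer labels, that pads each batch with `torch.nn.utils.rnn.pad_sequence` and also returns the original lengths (`pad_collate` is a hypothetical name):

```python=
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# hypothetical toy data: (sequence, label) pairs with different sequence lengths
samples = [
    (torch.tensor([1.0, 2.0, 3.0]), 0),
    (torch.tensor([4.0]), 1),
    (torch.tensor([5.0, 6.0]), 0),
]

def pad_collate(batch):
    seqs, labels = zip(*batch)  # transpose the list of tuples
    lengths = torch.tensor([len(s) for s in seqs])
    padded = pad_sequence(list(seqs), batch_first=True, padding_value=0.0)
    return padded, torch.tensor(labels), lengths

# a plain list works as a map-style dataset: it has __getitem__ and __len__
dl = DataLoader(samples, batch_size=3, collate_fn=pad_collate)
x, y, lengths = next(iter(dl))
print(x)
# tensor([[1., 2., 3.],
#         [4., 0., 0.],
#         [5., 6., 0.]])
print(y, lengths)  # tensor([0, 1, 0]) tensor([3, 1, 2])
```

Returning `lengths` alongside the padded tensor lets downstream code (for example, masking) tell real values from padding.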
## Reference
- [How to use 'collate_fn' with dataloaders?](https://stackoverflow.com/questions/65279115/how-to-use-collate-fn-with-dataloaders)