---
# System prepended metadata

title: 'Torch: Dataset, DataLoader'
tags: [DL]

---

## Prerequisites
Python basics:
- magic method
- zip
- iterable, iterator, generator
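A quick refresher on two of these idioms as they are used later in this note: `zip(*...)` transposes a list of tuples (this is how a `collate_fn` regroups samples), and `iter()`/`next()` pull items from an iterable (as in `next(iter(dl))`). The variable and function names below are illustrative only.

```python
# zip(*pairs) "transposes" a list of tuples: it groups the i-th field of every tuple
pairs = [(1, "a"), (2, "b"), (3, "c")]
numbers, letters = zip(*pairs)
print(numbers)  # (1, 2, 3)
print(letters)  # ('a', 'b', 'c')

# A list is an iterable: iter() creates a fresh iterator over it each time
data = [10, 20, 30]
it = iter(data)
first = next(it)   # 10
second = next(it)  # 20


# A generator is one kind of iterator: it is consumed as you advance it
def squares(n):
    for i in range(n):
        yield i * i


gen = squares(3)
print(list(gen))  # [0, 1, 4]
print(list(gen))  # [] (already exhausted)
```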


## Introduction
When training deep learning (DL) models, we usually need to preprocess the data and feed it to the model in batches. PyTorch (`torch`) provides two utility classes for this, `torch.utils.data.Dataset` and `torch.utils.data.DataLoader`: a `Dataset` stores the data and returns one sample at a time, and a `DataLoader` groups those samples into batches.

## Examples
A typical usage is as follows:

0. Import the modules
```python=
import torch
from torch.utils.data import Dataset, DataLoader
```

1. Define a custom dataset class and instantiate one object
```python=
class MyDataSet(Dataset):
    def __init__(self, x, y) -> None:
        super().__init__()
        self.x = x
        self.y = y
    
    def __getitem__(self, index):
        return self.x[index], self.y[index]
    
    def __len__(self) -> int:
        return len(self.x)

# B: number of samples, L: sequence length, C: feature dimension
B, L, C = 3, 4, 5

x = torch.randn((B, L, C))
y = torch.randn(B)
ds = MyDataSet(x, y)
```

2. Instantiate a DataLoader object

```python=
dl = DataLoader(ds, batch_size=2)
```

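3. Feed the batches into the model
```python=
for batch in dl:
    x_batch, y_batch = batch  # the default collate_fn returns [x_batch, y_batch]
    output = model(x_batch)   # `model`: any nn.Module defined elsewhere
    ...                       # compute the loss, backpropagate, update the weights
```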
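Putting steps 0 to 3 together gives a minimal end-to-end sketch. The model, loss, optimizer, and learning rate are hypothetical placeholders (the loop above leaves `model` undefined); only the `Dataset`/`DataLoader` part mirrors the steps above. Note that 3 samples with `batch_size=2` give a full batch followed by a smaller last batch.

```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader


class MyDataSet(Dataset):
    def __init__(self, x, y) -> None:
        super().__init__()
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self) -> int:
        return len(self.x)


torch.manual_seed(0)
N, L, C = 3, 4, 5  # number of samples, sequence length, feature dimension
ds = MyDataSet(torch.randn(N, L, C), torch.randn(N))
dl = DataLoader(ds, batch_size=2)

# Hypothetical model: flatten each (L, C) sample and regress a single scalar
model = nn.Sequential(nn.Flatten(), nn.Linear(L * C, 1))
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

batch_shapes = []
for x_batch, y_batch in dl:                  # each batch is [x_batch, y_batch]
    output = model(x_batch).squeeze(-1)      # (batch, 1) -> (batch,) to match y_batch
    loss = loss_fn(output, y_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    batch_shapes.append(tuple(x_batch.shape))

print(batch_shapes)  # [(2, 4, 5), (1, 4, 5)]
```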


## Custom Dataset

We usually build a custom dataset by subclassing `torch.utils.data.Dataset` and implementing three magic methods:
- `__init__()`: store raw data
- `__getitem__()`: implements indexing
- `__len__()`: implements `len()`
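To illustrate how these methods are used, here is a sketch with a hypothetical `ScaledDataset` that does its preprocessing lazily in `__getitem__()`, one sample at a time. `ds[i]` calls `__getitem__(i)` and `len(ds)` calls `__len__()`. For the simple case of indexing into tensors with no preprocessing, PyTorch's built-in `torch.utils.data.TensorDataset` already returns `(x[i], y[i])` tuples.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class ScaledDataset(Dataset):
    """Hypothetical example: per-sample preprocessing done lazily in __getitem__."""

    def __init__(self, x, y, scale: float = 2.0) -> None:
        super().__init__()
        self.x = x  # store the raw data only
        self.y = y
        self.scale = scale

    def __getitem__(self, index):
        # Preprocessing happens here, only when a sample is requested
        return self.x[index] * self.scale, self.y[index]

    def __len__(self) -> int:
        return len(self.x)


x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
y = torch.tensor([0, 1, 0])
ds = ScaledDataset(x, y)

print(len(ds))   # 3, via __len__
xi, yi = ds[1]   # via __getitem__
print(xi, yi)    # x[1] scaled by 2, and y[1]

# Built-in alternative when no preprocessing is needed: returns (x[i], y[i])
tds = TensorDataset(x, y)
print(len(tds), tds[1])
```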

## DataLoader
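`DataLoader` takes a `torch.utils.data.Dataset` as its first constructor argument. It is an iterable rather than a generator: every `for` loop (or `iter()` call) over it creates a fresh iterator that yields batches.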
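A small check of this behavior, using the built-in `TensorDataset` only to keep the example short: the `DataLoader` defines `__iter__` but not `__next__`, and looping over it twice gives the same batches, whereas a generator would be exhausted after the first pass.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(6).float())  # 6 samples, each a 1-tuple
dl = DataLoader(ds, batch_size=4)

# An iterable, not an iterator: next(dl) would raise a TypeError
print(hasattr(dl, "__iter__"), hasattr(dl, "__next__"))  # True False

# Each loop starts a new pass over the data
first_pass = [b[0].tolist() for b in dl]
second_pass = [b[0].tolist() for b in dl]
print(first_pass)                 # [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0]]
print(first_pass == second_pass)  # True: not exhausted after one pass

it = iter(dl)       # an explicit iterator over batches
print(next(it)[0])  # the first batch
```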

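Some important constructor parameters:
- `batch_size`: number of samples per batch (default `1`)
- `shuffle`: whether to reshuffle the data at every epoch (default `False`)
- `pin_memory`: copy each batch into page-locked (pinned) host memory so that transfers to a CUDA GPU are faster; usually set to `True` when training on a GPU
- `collate_fn`: function that merges a list of samples into one batch (see below)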
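A short sketch of the effect of these parameters on a toy dataset of the integers 0 to 9. The seeded `torch.Generator` passed via the `generator` argument only makes the shuffled order reproducible; `pin_memory` is enabled only when CUDA is available, since pinning is only useful before copying to a GPU.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))  # samples 0..9

# batch_size: 10 samples in batches of 4 -> sizes 4, 4, 2
dl = DataLoader(ds, batch_size=4)
sizes = [len(b[0]) for b in dl]
print(len(dl), sizes)  # 3 [4, 4, 2]

# shuffle=True: a new random order every epoch; the seeded generator makes it reproducible
g = torch.Generator().manual_seed(0)
dl_shuffled = DataLoader(ds, batch_size=4, shuffle=True, generator=g)
epoch = torch.cat([b[0] for b in dl_shuffled])
print(sorted(epoch.tolist()) == list(range(10)))  # True: every sample appears exactly once

# pin_memory: only worth enabling when batches will be copied to a CUDA GPU
dl_gpu = DataLoader(ds, batch_size=4, pin_memory=torch.cuda.is_available())
```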

### How DataLoader prepares a batch

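When iterating, `DataLoader` first fetches `batch_size` samples from the `Dataset` by calling `__getitem__()` once per index. Since each sample (datapoint) here is a tuple `(x_i, y_i)`, this gives **a list of tuples**. The list is then passed to `collate_fn()`, which merges it into the batch that the loop receives. Converting NumPy arrays or Python numbers into `torch.Tensor` is done by the default `collate_fn()`, not before it.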
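The steps above can be reproduced by hand. This is a simplified sketch of single-process loading without shuffling (`num_workers=0`); the real `DataLoader` also handles worker processes, memory pinning, and more. It uses `BatchSampler`, `SequentialSampler`, and `default_collate` from `torch.utils.data`; `default_collate` is assumed to be importable there, which holds in recent PyTorch versions.

```python
import torch
from torch.utils.data import (BatchSampler, DataLoader, SequentialSampler,
                              TensorDataset, default_collate)

torch.manual_seed(0)
x, y = torch.randn(5, 2), torch.randn(5)
ds = TensorDataset(x, y)

# 1) a sampler yields indices, 2) a batch sampler groups them,
# 3) samples are fetched with __getitem__, 4) the list goes to collate_fn
batch_sampler = BatchSampler(SequentialSampler(ds), batch_size=2, drop_last=False)
manual_batches = []
for indices in batch_sampler:             # [0, 1], [2, 3], [4]
    samples = [ds[i] for i in indices]    # a list of (x_i, y_i) tuples
    manual_batches.append(default_collate(samples))

# A DataLoader with the same settings yields the same batches
for (mx, my), (bx, by) in zip(manual_batches, DataLoader(ds, batch_size=2)):
    assert torch.equal(mx, bx) and torch.equal(my, by)
print(len(manual_batches))  # 3
```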

Let's first look at the raw list of tuples by passing an identity function as `collate_fn`, so that nothing is merged:
```python=
dl_no_collate = DataLoader(ds, batch_size=3, collate_fn=lambda x: x)

raw_batch = next(iter(dl_no_collate))
raw_batch
```

The output has the form `[(x1, y1), (x2, y2), (x3, y3)]`: a list of tuples, where each `xi` has shape `(L, C)` and each `yi` is a 0-dimensional tensor.
```python=
[(tensor([[ 1.0475,  0.0485, -0.2177,  1.4522, -0.5661],
          [ 0.3699,  1.0205,  0.8040,  0.2916,  0.0446],
          [-0.4648,  3.2485,  0.2894,  0.4325, -0.5360],
          [-0.7087,  1.1254,  0.0747, -1.0728, -0.7503]]),
  tensor(0.9723)),
 
 (tensor([[ 0.0565, -0.8368,  0.9353, -0.9156, -0.4951],
          [ 0.3750,  0.5120, -0.9690,  0.4152, -1.1938],
          [ 0.9016,  0.0136, -1.0142,  1.8649, -0.1401],
          [-0.5947,  0.0313, -0.6279,  1.5463, -0.6546]]),
  tensor(0.8549)),
 
 (tensor([[-0.9101, -0.2072, -0.0353, -0.5205, -1.2083],
          [-1.6441, -1.0244,  0.1473, -0.3761, -1.0025],
          [-1.2569, -0.4252,  0.1050,  1.4469,  0.4834],
          [ 0.2454,  0.4732, -0.2637,  0.8446, -1.0841]]),
  tensor(1.0406))]
```

The default `collate_fn()` groups the same field (attribute) from every sample and stacks each group with `torch.stack()`.

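For a dataset like ours, where each sample is a tuple of tensors, the following function produces the same result as the default `collate_fn()` (the built-in version additionally handles NumPy arrays, Python numbers, dicts, and nested structures):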
```python=
def default_collate_fn(data):
    '''
    data: 
        [(x1, y1), 
         (x2, y2),
         (x3, y3)]
    attributes:
        [(x1, x2, x3),
         (y1, y2, y3)]
    '''
    attributes = list(zip(*data))

    ret = []
    for att in attributes:
        ret.append(torch.stack(att))

    return ret
```

```python=
dl = DataLoader(ds, batch_size=3, collate_fn=default_collate_fn)

batch = next(iter(dl))
batch
```

The resulting batch is a list `[x, y]`, where `x` has shape `(3, L, C)` and `y` has shape `(3,)`:
```python=
[tensor([[[ 1.0475,  0.0485, -0.2177,  1.4522, -0.5661],
          [ 0.3699,  1.0205,  0.8040,  0.2916,  0.0446],
          [-0.4648,  3.2485,  0.2894,  0.4325, -0.5360],
          [-0.7087,  1.1254,  0.0747, -1.0728, -0.7503]],
 
         [[ 0.0565, -0.8368,  0.9353, -0.9156, -0.4951],
          [ 0.3750,  0.5120, -0.9690,  0.4152, -1.1938],
          [ 0.9016,  0.0136, -1.0142,  1.8649, -0.1401],
          [-0.5947,  0.0313, -0.6279,  1.5463, -0.6546]],
 
         [[-0.9101, -0.2072, -0.0353, -0.5205, -1.2083],
          [-1.6441, -1.0244,  0.1473, -0.3761, -1.0025],
          [-1.2569, -0.4252,  0.1050,  1.4469,  0.4834],
          [ 0.2454,  0.4732, -0.2637,  0.8446, -1.0841]]]),
 
 tensor([0.9723, 0.8549, 1.0406])]
```
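To check the hand-written version against PyTorch's own default, the sketch below compares it with `torch.utils.data.default_collate` (assumed importable from `torch.utils.data`, as in recent PyTorch versions). For tuple samples the built-in also returns a list, and unlike the hand-written version it converts non-tensor values such as Python floats and ints into tensors.

```python
import torch
from torch.utils.data import default_collate


def default_collate_fn(data):
    # Same hand-written version as above: transpose, then stack each field
    return [torch.stack(att) for att in zip(*data)]


torch.manual_seed(0)
samples = [(torch.randn(4, 5), torch.randn(())) for _ in range(3)]

mine = default_collate_fn(samples)
theirs = default_collate(samples)  # PyTorch's built-in default
print(type(theirs).__name__)       # list
print(all(torch.equal(a, b) for a, b in zip(mine, theirs)))  # True

# The built-in also handles Python numbers: floats -> float64 tensor, ints -> int64 tensor
converted = default_collate([(1.0, 2), (3.0, 4)])
print(converted)
```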




### Practical usage
For practical usage of `DataLoader`, please see [here](https://hackmd.io/@-CDCNK_qTUicXsissQsHMA/SJ6Gjpxv8).
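A common reason to write a custom `collate_fn` is variable-length data, where `torch.stack()` fails because the samples have different shapes. Below is a minimal sketch, assuming hypothetical `(sequence, label)` samples, that pads each batch to its longest sequence with `torch.nn.utils.rnn.pad_sequence` and also returns the original lengths. A plain Python list is used as the dataset, since `DataLoader` only needs `__getitem__` and `__len__`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical dataset: sequences of different lengths, one label each
data = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6]), 0),
]


def pad_collate_fn(batch):
    """Pad sequences to the longest one in the batch (torch.stack alone would fail)."""
    seqs, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs])
    padded = pad_sequence(list(seqs), batch_first=True, padding_value=0)  # (batch, max_len)
    return padded, torch.tensor(labels), lengths


dl = DataLoader(data, batch_size=3, collate_fn=pad_collate_fn)
padded, labels, lengths = next(iter(dl))
print(padded)           # shape (3, 3): shorter sequences padded with 0 on the right
print(labels, lengths)
```

Returning `lengths` alongside the padded tensor lets later code mask out or pack the padding positions.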

## Reference
- [How to use 'collate_fn' with dataloaders?](https://stackoverflow.com/questions/65279115/how-to-use-collate-fn-with-dataloaders)