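## Prerequisites

Basic Python concepts:
- magic methods
- `zip`
- iterables, iterators, and generators

## Introduction

When training deep learning (DL) models, we often need to preprocess the data and find a way to feed it into the model. The deep learning library `torch` provides two utility classes, `torch.utils.data.Dataset` and `torch.utils.data.DataLoader`, to make this easier.

## Examples

A typical usage is as follows.

0. Import

```python=
import torch
from torch.utils.data import Dataset, DataLoader
```

1. Define a custom dataset class and instantiate an object of it

```python=
class MyDataSet(Dataset):
    def __init__(self, x, y) -> None:
        super().__init__()
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self) -> int:
        return len(self.x)


# number of data points, sequence length, number of features
B, L, C = 3, 4, 5
x = torch.randn((B, L, C))
y = torch.randn((B,))
ds = MyDataSet(x, y)
```

2. Instantiate a `DataLoader` object

```python=
dl = DataLoader(ds, batch_size=2)
```

3. Feed the data into the model

```python=
model = torch.nn.Linear(C, 1)  # hypothetical model, for illustration only

for batch in dl:
    x, y = batch       # x: (batch, L, C), y: (batch,)
    output = model(x)  # (batch, L, 1)
    ...                # compute the loss, backpropagate, etc.
```

## Custom Dataset

We usually build a custom dataset by subclassing `torch.utils.data.Dataset`. For a map-style dataset like this one, we only need to implement three magic methods in our class:
- `__init__()`: stores the raw data
- `__getitem__()`: implements indexing, i.e. `ds[i]` returns the `i`-th data point
- `__len__()`: implements `len(ds)`, i.e. the number of data points

## DataLoader

`DataLoader` is an iterable (not a generator): it wraps the `torch.utils.data.Dataset` passed to its constructor, and each call to `iter(dl)` returns a new iterator that yields batches. Some important constructor parameters are:
- `batch_size`: how many data points go into each batch (default `1`)
- `shuffle`: whether to reshuffle the data at every epoch (default `False`)
- `pin_memory`: if `True`, batches are copied into page-locked (pinned) host memory, which cannot be swapped out and allows faster copies to the GPU; commonly set to `True` when training on CUDA (default `False`)
- `collate_fn`: the function that merges a list of data points into a batch

### How DataLoader prepares a batch

When iterating, `DataLoader` first picks the indices of the next batch (in order, or randomly if `shuffle=True`) and fetches each data point from the `Dataset` with `__getitem__()`. This gives **a list of data points**; here each data point is a tuple `(x, y)`, so we get a list of tuples. Finally, the list is passed to `collate_fn()`, and its return value is the batch. `DataLoader` does not convert the data points to `torch.Tensor` at this stage: in this example they are already tensors, and it is the default `collate_fn()` that converts NumPy arrays and Python numbers into tensors.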
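The steps above can be sketched as a plain Python generator. This is a simplified, hypothetical re-implementation (`simple_batches`, `stack_collate` and `PairDataset` are names made up for illustration, not PyTorch APIs), assuming a map-style dataset, no shuffling and a single process; the real `DataLoader` delegates index selection to a sampler and can load data in worker processes. The last lines compare its batches with those of a real `DataLoader` on a toy dataset.

```python=
import torch
from torch.utils.data import DataLoader, Dataset


def stack_collate(samples):
    """Turn [(x1, y1), (x2, y2), ...] into [stack(x1, x2, ...), stack(y1, y2, ...)]."""
    return [torch.stack(field) for field in zip(*samples)]


def simple_batches(dataset, batch_size, collate_fn=stack_collate, drop_last=False):
    """Hypothetical, simplified version of DataLoader's batching loop.

    Assumes a map-style dataset, sequential (unshuffled) order and a single
    process; the real DataLoader adds samplers, worker processes, pinning, etc.
    """
    n = len(dataset)  # relies on __len__
    for start in range(0, n, batch_size):
        indices = range(start, min(start + batch_size, n))
        if drop_last and len(indices) < batch_size:
            break  # skip the last, incomplete batch
        samples = [dataset[i] for i in indices]  # relies on __getitem__: a list of tuples
        yield collate_fn(samples)


class PairDataset(Dataset):
    """Hypothetical toy dataset with the same structure as MyDataSet."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)


toy_ds = PairDataset(torch.arange(10.0).reshape(5, 2), torch.arange(5.0))

ours = list(simple_batches(toy_ds, batch_size=2))
theirs = list(DataLoader(toy_ds, batch_size=2))
for (xa, ya), (xb, yb) in zip(ours, theirs):
    print(xa.shape, ya.shape, torch.equal(xa, xb) and torch.equal(ya, yb))
# torch.Size([2, 2]) torch.Size([2]) True
# torch.Size([2, 2]) torch.Size([2]) True
# torch.Size([1, 2]) torch.Size([1]) True
```

With 5 data points and `batch_size=2`, the last batch holds a single data point; passing `drop_last=True` skips it, mirroring the `drop_last` parameter of `DataLoader`.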
Let's first look at the raw list of data points. Passing the identity function `lambda x: x` as `collate_fn` returns the list unchanged, so nothing gets stacked:

```python=
dl_no_collate = DataLoader(ds, batch_size=3, collate_fn=lambda x: x)
raw_batch = next(iter(dl_no_collate))
raw_batch
```

The output is of the form `[(x1, y1), (x2, y2), (x3, y3)]`, a list of tuples:

```python=
[(tensor([[ 1.0475,  0.0485, -0.2177,  1.4522, -0.5661],
          [ 0.3699,  1.0205,  0.8040,  0.2916,  0.0446],
          [-0.4648,  3.2485,  0.2894,  0.4325, -0.5360],
          [-0.7087,  1.1254,  0.0747, -1.0728, -0.7503]]),
  tensor(0.9723)),
 (tensor([[ 0.0565, -0.8368,  0.9353, -0.9156, -0.4951],
          [ 0.3750,  0.5120, -0.9690,  0.4152, -1.1938],
          [ 0.9016,  0.0136, -1.0142,  1.8649, -0.1401],
          [-0.5947,  0.0313, -0.6279,  1.5463, -0.6546]]),
  tensor(0.8549)),
 (tensor([[-0.9101, -0.2072, -0.0353, -0.5205, -1.2083],
          [-1.6441, -1.0244,  0.1473, -0.3761, -1.0025],
          [-1.2569, -0.4252,  0.1050,  1.4469,  0.4834],
          [ 0.2454,  0.4732, -0.2637,  0.8446, -1.0841]]),
  tensor(1.0406))]
```

The default `collate_fn()` gathers the same field from every data point in the list and stacks each group with `torch.stack()`: all the `x`s become one tensor and all the `y`s become another.
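To see the tensor conversion performed by the default `collate_fn()`, we can call it directly. PyTorch exposes it as `default_collate`, importable from `torch.utils.data` in recent versions (1.11 and later, as far as we know; treat the version as an assumption). The sketch below feeds it a hand-made list of `(NumPy array, Python float)` data points instead of tensors.

```python=
import numpy as np
import torch
from torch.utils.data import default_collate  # assumes PyTorch >= 1.11

# A hand-made "raw batch": a list of (features, label) data points.
# The features are NumPy arrays and the labels are plain Python floats, not tensors.
samples = [
    (np.zeros((4, 5), dtype=np.float32), 0.5),
    (np.ones((4, 5), dtype=np.float32), 1.5),
    (np.full((4, 5), 2.0, dtype=np.float32), 2.5),
]

batch = default_collate(samples)
x, y = batch
print(type(batch).__name__)  # list
print(x.shape, x.dtype)      # torch.Size([3, 4, 5]) torch.float32
print(y.dtype)               # torch.float64
```

The arrays are converted and stacked into a single `float32` tensor, and the Python floats become a `float64` tensor. Because each data point is a tuple, the result is a list `[x, y]`.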
The following code reproduces what the default `collate_fn()` does for this dataset. It is a simplified version: the real default also handles NumPy arrays, Python numbers, dicts, and nested structures.

```python=
def default_collate_fn(data):
    '''
    data:       [(x1, y1), (x2, y2), (x3, y3)]
    attributes: [(x1, x2, x3), (y1, y2, y3)]
    '''
    attributes = list(zip(*data))  # transpose: group the same field of every data point
    ret = []
    for att in attributes:
        ret.append(torch.stack(att))  # (x1, x2, x3) -> one tensor with a new batch dimension
    return ret
```

```python=
dl = DataLoader(ds, batch_size=3, collate_fn=default_collate_fn)
batch = next(iter(dl))
batch
```

The resulting batch is of the form `[x, y]`, with `x` of shape `(3, 4, 5)` and `y` of shape `(3,)`:

```python=
[tensor([[[ 1.0475,  0.0485, -0.2177,  1.4522, -0.5661],
          [ 0.3699,  1.0205,  0.8040,  0.2916,  0.0446],
          [-0.4648,  3.2485,  0.2894,  0.4325, -0.5360],
          [-0.7087,  1.1254,  0.0747, -1.0728, -0.7503]],

         [[ 0.0565, -0.8368,  0.9353, -0.9156, -0.4951],
          [ 0.3750,  0.5120, -0.9690,  0.4152, -1.1938],
          [ 0.9016,  0.0136, -1.0142,  1.8649, -0.1401],
          [-0.5947,  0.0313, -0.6279,  1.5463, -0.6546]],

         [[-0.9101, -0.2072, -0.0353, -0.5205, -1.2083],
          [-1.6441, -1.0244,  0.1473, -0.3761, -1.0025],
          [-1.2569, -0.4252,  0.1050,  1.4469,  0.4834],
          [ 0.2454,  0.4732, -0.2637,  0.8446, -1.0841]]]),
 tensor([0.9723, 0.8549, 1.0406])]
```

### Practical usage

For practical usage of `DataLoader`, please see [here](https://hackmd.io/@-CDCNK_qTUicXsissQsHMA/SJ6Gjpxv8).

## Reference

- [How to use 'collate_fn' with dataloaders?](https://stackoverflow.com/questions/65279115/how-to-use-collate-fn-with-dataloaders)