# 資料流
[TOC]
## 簡介
$\qquad$在大數據的時代,資料的大小動輒是GB或是TB等級,這樣的資料不可能一次性的被儲存到記憶體之上並提供給機器學習模型進行訓練。這種時候我們需要做事情就是讓資料從硬碟上直接送到GPU本身進行訓練,這樣就能避免記憶體大小對我們的資料造成的限制。在[利用TensorFlow以及PyTorch建立張量](/e3tpZ2mqQwiAqEV9nmZUBQ)章節中,我們簡單示範了如何利用生成器(`generator`)物件實作資料流。在本章節中,我們將示範如何建立一個基於`torch.utils.data.Dataset`物件的資料流。有人可能會問「都有基於生成器的作法了,為何還需要基於前述物件的實作呢?」,原因如下:
1. 當取用資料時,擁有更安全的多進程(Multiprocessing)及多線程(Multi-threading)過程
$\qquad$在利用生成器的實作範例中,生成器只是根據既有的規則以及設定,不斷的從硬碟終將資料送模型。若要實作打散資料或是在每個訓練週期中對資料進行重新排列,使用者往往會需要付出更大的心力進行編寫。同時,因為其資料流已經規則已經在建構初期救定下來,若非使用者在一開始就實作了多線程/多進程,將無法在訓練過程中透過其他方式改善資料流的效率。
3. 更靈活的資料操作
$\qquad$一個繼承或`torch.utils.data.Dataset`物件的資料流實作,將可以繼承這兩個物件本身已經預先設計好的功能(例如:`on_epoch_end`等功能),使用者無須再花時間及精力來實作一些進階功能。
## PyTorch實作
$\qquad$在PyTorch中的實作,基本上與在TensorFlow無異。然而在PyTorch中,特別將實作分成了兩種類別,分別為:**Map-style datasets**以及**Iterable-style datasets**。其中,**Map-style datasets**的實作與TensorFLow的實作沒有太多差異,唯一差異在PyTorch的版本不需要實作`batch_size`的取用,這部份會由`torch.utils.data.DataLoader`來完成。以下是在PyTorch中,**Map-style datasets**以及**Iterable-style datasets**的差異。
### 說明及實作範例
1. Map-style datasets
$\qquad$一個Map-style datasets指的是實作了`__getitem__()` 以及`__len__()`方法的資料集。此物件以及資料流的實作需要繼承`torch.utils.data.Dataset`類別。範例如下:
```python=
import torch
import glob
import cv2
import torch.nn.functional as F
class img_pipeline(torch.utils.data.Dataset):
def __init__(self, path, mode, transform=None):
self.path = path
self.mode = mode
self.img_list = sorted(glob.glob(self.path+f"{mode}/"+"*.jpg"))
self.label_list = sorted(glob.glob(self.path+f"{mode}/"+"*.txt"))
self.transform = transform
def __len__(self):
return len(self.img_list)
def __getitem__(self, index):
if self.transform is not None:
_img = cv2.imread(self.img_list[index])
_img = self.transform(_img)
else:
_img = cv2.imread(self.img_list[index])
with open( self.label_list[index], 'r')as f:
_label = torch.tensor(int(f.readlines()[0]))
return _img, F.one_hot(_label, num_classes=800)
```
2. Iterable-style datasets
$\qquad$一個Iterable-style datasets指的是實作了`__iter__()`方法的資料集。此物件以及資料流的實作需要繼承`torch.utils.data.IterableDataset`類別。此物件的特點是提供了一個**可迭代(Iterable)** 的資料集操作。資料的取用方式不是透過`[]`以及相對索引的搭配來取用,而是透過對一個可迭代對象[^1]的迭代取值來實現。此方式在處理串流媒體資料時非常有用。
```python=
class img_pipeline_itertype(torch.utils.data.IterableDataset):
def __init__(self, path, mode, start, end, transform=None):
self.path = path
self.mode = mode
self.img_list = glob.glob(self.path+f"{mode}/"+"*.jpg")[:10]
self.label_list = sorted(glob.glob(self.path+f"{mode}/"+"*.txt"))[:10]
self.transform = transform
assert end > start, "End should be larger than start"
self.start = start
self.end = end
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
img = []
label = []
if self.transform is not None:
for idx in range(iter_start, iter_end):
_img = cv2.imread(self.img_list[idx])
_img = self.transform(_img)
img.append(_img)
else:
for idx in range(iter_start, iter_end):
_img = cv2.imread(self.img_list[idx])
img.append(_img)
for idx in range(iter_start, iter_end):
with open( self.label_list[idx], 'r')as f:
_label = torch.tensor(int(f.readlines()[0]))
label.append(F.one_hot(_label, num_classes=800))
_data = [[a, b] for a, b in zip(img, label)]
return iter(_data)
```
### 包裝成`torch.utils.data.DataLoader`物件
$\qquad$在PyTorch中,不像TensorFlow中可以直接將繼承了`tf.keras.utils.Sequence`的物件送進`.fit()`函數進行訓練。在PyTorch中,需要先將物件包含成`torch.utils.data.DataLoader`物件,再將變數送進模型進行訓練。這樣做的必要性在:(一)`torch.utils.data.DataLoader`的包裝同時包含了對`batch_size`的定義,`sampler`的宣告以及多線程的實現。`sampler`是隨機打散資料集所需要的函數,而在包裝時宣告的`num_works`以及其他變數將讓資料的取用可以藉由多線程來進行。作法很簡單,只要如下方範例執行即可(以`Map-style datasets`中的物件為範例)[^2]:
```python=
train_dataset = torch.utils.data.DataLoader(
img_pipeline(PATH, 'train', transform=my_transform),
batch_size= 2,
num_workers=10)
```
## 結論
除去在PyTorch中個`Iterable-style datasets`,在TensorFlow與PyTorch之中對資料流的實作基本上差不多,讀者可以在符合自身需求下進行實作。
[^1]: 可迭代對象就是有實作`__iter__()`方法的變數,諸如`List`以及`numpy.ndrray`都屬於此類別。
[^2]: 詳細參數可在[官方說明書](https://pytorch.org/docs/stable/data.html#map-style-datasets)中查看。
###### tags: `Machine Learning` `Notebook` `技術隨筆` `機器學習` `Python` `PyTorch`