ML
DL
pytorch
Pytorch 可以把每一個 tensor 都放在不同的 device 裡,所以如果要寫一份 code 可以在不同的 device 上面跑,必須要先做規劃,不然要換 device 的時候會變得非常麻煩。
Pytorch doc 也有基本的介紹,這裡算是從不同角度切入來看,再從自己踩過的坑補充一些細節。
torch.ones([2,5], device='cuda')
torch.ones([2,5]).to('cuda')
torch.ones_like(tensor)
(會得到一個和 tensor
一樣 shape, dtype, device 的 tensor)model.to('cuda')
其中,model.to
的作法會把 model 裡的所有
全部都換到指定的 device。
其實這 3 種幾乎已經包含所有 model 中會用到的 tensor 了。底下會說明怎麼利用這些特性,才能用最簡單的 model.to
一次把所有需要的 tensor 都換到正確的 device。
如果在 model 裡需要紀錄自訂的 tensor,分為兩種做法:
self.my_param = nn.Parameter(tensor)
self.register_buffer('my_buffer', tensor)
self.my_buffer
拿到這個 tensormodel.state_dict()
裡看到這樣,這些 tensor 就都會在 model.to
的時候轉換到需要的 device。
需要特別注意的是 buffer 的部分,如果直接寫 self.my_buffer = tensor
,那這個 tensor 就不會被 model track,也就是在 model.state_dict()
裡不會拿到,同時,在 model.to
的時候也不會轉換 self.my_buffer
的 device。
如果在 model 裡訂一個 attribute self.my_module = my_module
,而且 my_module
是 nn.Module 的 instance(isinstance(my_module, nn.Module)
),那 my_module
就是 model 的 submodule。
所以說如果你剛好想要一個 list 的 modules 作為 attribute:
self.my_modules = [my_module_1, my_module_2]
,那 my_module_1
和 my_module_2
就不是 model 的 submodule,在 model.state_dict
和 model.parameters
都不會出現。self.my_modules = nn.ModuleList([my_module_1, my_module_2])
。同理,如果想要一個 dict 的 modules 就用 nn.ModuleDict
,如果要一個 list 的 parameters 就用 nn.ParameterList
,要一個 dict 的 paramters 就用 nn.ParameterDict
。
這樣這些 attributes 才會真正成為這個 model 的 submodule/parameter,支援各種功能,比如 model.parameters()
會 iterate over 所有 paramters 和 submodule 中的 parameters;model.state_dict()
會包含所有 submodule 中的 parameters 和 buffers,然後 model.to('cuda')
的時候也才會把全部 submodule 中的 tensors 都轉成 cuda tensor。
如果在 forward 時需要生成新的 tensor,比如 VAE 裡的 random noise
https://github.com/pytorch/examples/blob/5b1f45057dc14a5e2132b45233c258a1dc2a0aab/vae/main.py#L55
eps = torch.randn(*std.shape)
,那這個 eps 就會放在 cpu 裡,整個 model 就沒辦法在其他 device 上跑了。eps = torch.randn(*std.shape, device=device)
),但是這樣等於是要在 model.to
以外的地方另外 maintain 一個自己的 device。torch.randn_like
,這樣就可以確保這個 tensor 的 device 一定會和 input 相同。torch.randn(*std.shape, device=std.device)
Input/target tensor 做 device 轉換的時機視情況有所不同。
input.to(next(self.parameters()).device)
做轉換。這是終極解決方案,只要 model 設定好 device 任何 tensor input model 都會進到這個 device。但是似乎很少看到有人這樣做,不知道為什麼?