Try   HackMD

Pytorch 模型儲存與使用

先建立一個模型

import torch
import torch.nn as nn

class ExampleModel(nn.Module):
    
    def __init__(self, input_size):
        super(ExampleModel, self).__init__()
        self.linear = nn.Linear(input_size, 1)
    
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = ExampleModel(input_size=6)

常用的副檔名為: .pt 以及 .pth

儲存模型

1. 儲存整個模型和權重

torch.save(model, './models/model.pt')

2. 只儲存權重

torch.save(model.state_dict(), './models/save.pt')

載入模型

1. 載入整個模型

model = torch.load('./models/model.pt')
model.eval() # 進入評估狀態 
predict = model(data[0].numpy()) # 进行预测
  • 在運行推理之前,必須調用model.eval(),將設置dropoutbatch normalization 層為評估模式。如果不這麼做,可能導致模型推斷結果不一致。
  • 訓練前則是使用model.train()來進行訓練

2. 只載入權重

model = ExampleModel(input_size=6) # 必須先創建模型
model.load_state_dict(torch.load('./models/save.pt'))
model.eval() # 進入評估狀態 

查看模型的參數

for parm in model.parameters():
    print(parm)

保存和加載 Checkpoint 用於推理/繼續訓練

1. 保存

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

2. 載入

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()