--- tags: PyTorch, Python --- # Pytorch 模型儲存與使用 ### 先建立一個模型 ```python import torch import torch.nn as nn class ExampleModel(nn.Module): def __init__(self, input_size): super(ExampleModel, self).__init__() self.linear = nn.Linear(input_size, 1) def forward(self, x): y_pred = torch.sigmoid(self.linear(x)) return y_pred model = ExampleModel(input_size=6) ``` --- :::info 常用的副檔名為: `.pt` 以及 `.pth` ::: ## 儲存模型 ### 1. 儲存整個模型和權重 ```python torch.save(model, './models/model.pt') ``` ### 2. 只儲存權重 ```python torch.save(model.state_dict(), './models/save.pt') ``` ## 載入模型 ### 1. 載入整個模型 ```python model = torch.load('./models/model.pt') model.eval() # 進入評估狀態 predict = model(data[0].numpy()) # 进行预测 ``` :::danger * 在運行推理之前,必須調用`model.eval()`,將設置`dropout` 和 `batch normalization` 層為<strong>評估模式</strong>。如果不這麼做,可能導致模型推斷結果不一致。 * 訓練前則是使用`model.train()`來進行訓練 ::: ### 2. 只載入權重 ```python model = ExampleModel(input_size=6) # 必須先創建模型 model.load_state_dict(torch.load('./models/save.pt')) model.eval() # 進入評估狀態 ``` ## 查看模型的參數 ```python for parm in model.parameters(): print(parm) ``` ## 保存和加載 Checkpoint 用於推理/繼續訓練 ### 1. 保存 ```python torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, ... }, PATH) ``` ### 2. 載入 ```python model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - or - model.train() ```