---
tags: PyTorch, Python
---
# Pytorch 模型儲存與使用
### 先建立一個模型
```python
import torch
import torch.nn as nn
class ExampleModel(nn.Module):
def __init__(self, input_size):
super(ExampleModel, self).__init__()
self.linear = nn.Linear(input_size, 1)
def forward(self, x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred
model = ExampleModel(input_size=6)
```
---
:::info
常用的副檔名為: `.pt` 以及 `.pth`
:::
## 儲存模型
### 1. 儲存整個模型和權重
```python
torch.save(model, './models/model.pt')
```
### 2. 只儲存權重
```python
torch.save(model.state_dict(), './models/save.pt')
```
## 載入模型
### 1. 載入整個模型
```python
model = torch.load('./models/model.pt')
model.eval() # 進入評估狀態
predict = model(data[0].numpy()) # 进行预测
```
:::danger
* 在運行推理之前,必須調用`model.eval()`,將設置`dropout` 和 `batch normalization` 層為<strong>評估模式</strong>。如果不這麼做,可能導致模型推斷結果不一致。
* 訓練前則是使用`model.train()`來進行訓練
:::
### 2. 只載入權重
```python
model = ExampleModel(input_size=6) # 必須先創建模型
model.load_state_dict(torch.load('./models/save.pt'))
model.eval() # 進入評估狀態
```
## 查看模型的參數
```python
for parm in model.parameters():
print(parm)
```
## 保存和加載 Checkpoint 用於推理/繼續訓練
### 1. 保存
```python
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
```
### 2. 載入
```python
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
```