import torch
import torch.nn as nn
class ExampleModel(nn.Module):
    """Binary classifier: one linear layer followed by a sigmoid.

    Args:
        input_size: number of features in the last dimension of the input.
    """

    def __init__(self, input_size: int):
        super().__init__()
        # Single linear map from ``input_size`` features to one output value.
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(linear(x))``; the last dimension of the result has size 1."""
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred


model = ExampleModel(input_size=6)
常用的副檔名為 `.pt` 以及 `.pth`。
import os

# torch.save does not create missing directories, so make sure ./models exists.
os.makedirs('./models', exist_ok=True)

# Method 1: pickle the whole model object (class reference + parameters).
torch.save(model, './models/model.pt')
# Method 2 (recommended): save only the parameters, i.e. the state_dict.
torch.save(model.state_dict(), './models/save.pt')

# Loading a whole pickled model executes arbitrary unpickling code, so only do this
# for files you trust. Since PyTorch 2.6, torch.load defaults to weights_only=True,
# which rejects full-model pickles; opt out explicitly (argument exists since 1.13).
model = torch.load('./models/model.pt', weights_only=False)
model.eval()  # switch to evaluation mode

# nn.Linear expects a tensor, not a NumPy array; torch.as_tensor is a no-op for
# tensors and converts array-likes. `data` comes from elsewhere — presumably a
# dataset whose items have 6 features (TODO confirm).
with torch.no_grad():  # no autograd graph needed for inference
    predict = model(torch.as_tensor(data[0]))  # run the prediction
在進行推論前呼叫 `model.eval()`，會將 dropout 和 batch normalization 層設置為評估模式；如果不這麼做，可能導致模型推斷結果不一致。若要繼續訓練，則呼叫 `model.train()` 切回訓練模式來進行訓練。
# A state_dict only stores tensors, so an instance of the model class must be
# created first; its architecture has to match the saved weights.
model = ExampleModel(input_size=6)
# A state_dict contains only tensors, so the safe weights-only unpickler is
# sufficient and avoids executing arbitrary pickled code.
state_dict = torch.load('./models/save.pt', weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode before inference

# Inspect the restored parameters together with their names.
for name, param in model.named_parameters():
    print(name, param)
# Save a general checkpoint so training can be resumed later.
# `epoch`, `optimizer`, `loss` and `PATH` are placeholders that come from the
# surrounding training loop (not shown here).
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # Add any other state needed to resume training here. Note that non-tensor
    # custom objects will not load with torch.load(..., weights_only=True).
}, PATH)
# Restore a general checkpoint. TheModelClass / TheOptimizerClass / args / kwargs /
# PATH are placeholders for your own model class, optimizer class, constructor
# arguments and checkpoint path.
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

# The checkpoint is expected to hold state_dicts plus `epoch` and `loss`
# (assumed to be a number or tensor — confirm), so the safe weights-only
# unpickler suffices. Use weights_only=False only for trusted files that
# contain other Python objects.
checkpoint = torch.load(PATH, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Choose ONE mode depending on why the checkpoint was loaded:
# model.eval()   # inference: dropout / batch norm use evaluation behaviour
model.train()    # resume training from `epoch` with the restored optimizer state
hackmd-github-sync-badge
Jun 19, 2024 — Read IR model Method 1 :::danger Problem: Recognizing the first image takes more time (initialization time) ::: net.setPreferableBackend 參數
Sep 8, 2022讀取格式為 BGR 讀取圖片 參數 用法 cv2.IMREAD_COLOR 正常顏色
Aug 28, 2022源碼放在:Colab 1. 載入需要的 Module import torch import torch.nn as nn import torchvision import matplotlib.pyplot as plt import torch.utils.data as Data from torch.autograd import Variable import os
Aug 9, 2022or
By clicking below, you agree to our terms of service.
New to HackMD? Sign up