
Classifying CIFAR-10 with PyTorch

The source code is available on Colab.

1. Import the required modules

import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import torch.utils.data as Data
import os

2. Set the hyperparameters

DOWNLOAD_DATASET = True
LR = 0.001
BATCH_SIZE=128
EPOCH = 10
MODELS_PATH = './models'

3. Data preprocessing

train_transform = torchvision.transforms.Compose([
    # torchvision.transforms.RandomCrop(32, 4),
    # torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])
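ToTensor scales each pixel to the range [0, 1]; Normalize with a per-channel mean and std of 0.5 then maps values to roughly [-1, 1] via (x - mean) / std. A minimal sanity check of that mapping (illustrative only, not part of the original notebook):

# (x - mean) / std with mean = std = 0.5: 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0
x = torch.tensor([0.0, 0.5, 1.0])
print((x - 0.5) / 0.5)  # tensor([-1.,  0.,  1.])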

4. Load the CIFAR-10 dataset

train_data = torchvision.datasets.CIFAR10(
    root='./cifar10',
    train=True,
    transform=train_transform,
    download=DOWNLOAD_DATASET
)

test_data = torchvision.datasets.CIFAR10(
    root='./cifar10',
    train=False,
    transform=test_transform,
    download=DOWNLOAD_DATASET
)

5. Put the training data into a Data.DataLoader

data_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

This lets training read one BATCH_SIZE worth of data per iteration instead of loading the entire dataset at once.
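A quick sketch (illustrative, not in the original notebook) to confirm what one batch looks like: each iteration of the DataLoader yields BATCH_SIZE normalized images of shape [3, 32, 32] together with their integer labels.

# Pull a single batch and inspect its shapes
images, labels = next(iter(data_loader))
print(images.shape)  # torch.Size([128, 3, 32, 32])
print(labels.shape)  # torch.Size([128])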

6. CIFAR-10 class names

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
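matplotlib is imported in step 1 but not used later; here is an optional sketch (assuming indexing train_data returns a normalized image tensor and an integer label) that displays one training image with its class name. Multiplying by 0.5 and adding 0.5 simply undoes the Normalize transform for display.

# Show the first training image with its class name
img, label = train_data[0]
plt.imshow((img.permute(1, 2, 0) * 0.5 + 0.5).numpy())  # CHW -> HWC, undo Normalize
plt.title(classes[label])
plt.show()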

7. Build the model

class CNN(nn.Module):
  def __init__(self, num_classes: int):
    super(CNN, self).__init__()
    self.num_classes = num_classes

    # in[N, 3, 32, 32] => out[N, 16, 16, 16]
    self.conv1 = nn.Sequential(
        nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=2
        ),
        nn.ReLU(True),
        nn.MaxPool2d(kernel_size=2)
    )
    # in[N, 16, 16, 16] => out[N, 32, 8, 8]
    self.conv2 = nn.Sequential(
        nn.Conv2d(16, 32, 5, 1, 2),
        nn.ReLU(True),
        nn.MaxPool2d(2)

    )
    # in[N, 32 * 8 * 8] => out[N, 128]
    self.fc1 = nn.Sequential(
        nn.Linear(32 * 8 * 8, 128),
        nn.ReLU(True)
    )
    # in[N, 128] => out[N, 64]
    self.fc2 = nn.Sequential(
        nn.Linear(128, 64),
        nn.ReLU(True)
    )
    # in[N, 64] => out[N, 10]
    self.out = nn.Linear(64, self.num_classes)

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # [N, 32 * 8 * 8]
    x = self.fc1(x)
    x = self.fc2(x)
    output = self.out(x)
    return output

8. Instantiate the model

cnn = CNN(len(classes))
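A quick shape check (illustrative only): pushing a dummy batch through the untrained network should yield one logit per class, matching the in/out comments in the model definition.

# Fake batch of 4 RGB 32x32 images -> expect logits of shape [4, 10]
dummy = torch.zeros(4, 3, 32, 32)
print(cnn(dummy).shape)  # torch.Size([4, 10])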

9. Set up the optimizer (Adam)

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)

10. Loss function

loss_function = nn.CrossEntropyLoss()

11. Train the model

for epoch in range(EPOCH):
  cnn.train()
  for step, (b_x, b_y) in enumerate(data_loader):
    out = cnn(b_x)
    loss = loss_function(out, b_y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
      print('Epoch: {} | Step: {} | Loss: {:.4f}'.format(epoch + 1, step, loss.item()))

12. Save the model

if not os.path.exists(MODELS_PATH):
  os.mkdir(MODELS_PATH)
torch.save(cnn, os.path.join(MODELS_PATH, 'cnn_model.pt'))
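Because torch.save(cnn, ...) pickles the whole module rather than just a state_dict, the CNN class definition must be available when loading. A minimal sketch of loading it back (on recent PyTorch versions torch.load may additionally need weights_only=False to unpickle a full module):

# Load the saved model and switch to evaluation mode before inference
loaded_cnn = torch.load(os.path.join(MODELS_PATH, 'cnn_model.pt'))
loaded_cnn.eval()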

13. Create the test loader

test_loader = Data.DataLoader(
    dataset=test_data,
    batch_size=test_data.data.shape[0],
    shuffle=False
)
test_x, test_y = next(iter(test_loader))

14. Evaluate accuracy on the full test set

cnn.eval()
with torch.no_grad():
  prediction = torch.argmax(cnn(test_x), 1)
acc = torch.eq(prediction, test_y).float().mean()
print('Accuracy: {:.2%}'.format(acc.item()))
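Beyond the overall number, a per-class breakdown often shows which categories this small CNN tends to confuse. A minimal sketch reusing prediction and test_y from above:

# Accuracy per class: fraction of correct predictions among samples of that class
correct = torch.eq(prediction, test_y)
for i, name in enumerate(classes):
    mask = (test_y == i)
    print('{:<12s}: {:.2%}'.format(name, correct[mask].float().mean().item()))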