###### tags: `PyTorch` # PyTorch - 搭建神經網絡 在 PyTorch 中搭建神經網絡有一個固定格式,[參考PyTorch 文本](https://pytorch.org/docs/stable/nn.html) 格式固定為 class Model,其中包含`__init__(self)` & `forward(self, x)` 如下: ```python= import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) ``` 來試試上一篇回歸的問題,使用這種方式搭建模型。 ## 1. 建立 DATA ```python= import torch import matplotlib.pyplot as plt X = torch.unsqueeze(torch.linspace(-1, 1, 200), dim=1) # x data (tensor), shape=(100, 1) Y = 2*X.pow(2) + 0.3*torch.rand(X.size()) # noisy y data (tensor), shape=(100, 1) plt.scatter(X.data.numpy(), Y.data.numpy()) plt.show() ``` ![](https://i.imgur.com/sHQVI8I.png) 將X,Y 丟進 Variable 做梯度計算用。 ```python= X=Variable(torch.Tensor(X.reshape(200,1))) Y=Variable(torch.Tensor(Y.reshape(200,1))) ``` ## 2. 建立神經網絡模型 先定義所有的層屬性(__init __()),然後再一層層搭建(forward(x))層於層的關係鏈接。 這邊先在 `__init__()` 中寫好要使用的 nn,然後在 forward()中將串接過程使用激活函數。 ```python= import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.nn1 = nn.Linear(1, 15) #第一層 Linear NN self.nn2 = nn.Linear(15, 1) #第二層 Linear NN def forward(self, x): x = F.relu(self.nn1(x)) #對第一層 NN 使用Relu激活 x = self.nn2(x) #第二層直接輸出 return x model = Model() print(model) #將模型print出來看看 ``` ``` Model( (nn1): Linear(in_features=1, out_features=15, bias=True) (nn2): Linear(in_features=15, out_features=1, bias=True) ) ``` 搭建好了之後,一樣選擇優化器&損失函數: ```python= optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005) loss_function = torch.nn.MSELoss() ``` ## 3. 直接將訓練過程可視化 ```python= epochs = 500 for epoch in range(epochs): prediction = model(X) loss = loss_function(prediction, Y) optimizer.zero_grad() loss.backward() optimizer.step() if epoch % 20 == 0: # plot and show learning process plt.cla() plt.scatter(X.data.numpy(), Y.data.numpy()) plt.plot(X.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'}) plt.savefig('D:\\img'+'%s'%epoch+'.jpg') plt.pause(0.1) ``` 製作成GIF如下: ![](https://i.imgur.com/clVtNGs.gif)