您好,登錄后才能下訂單哦!
本文主要是用PyTorch來實現一個簡單的回歸任務。
編輯器:spyder
1.引入相應的包及生成偽數據
import torch import torch.nn.functional as F # 主要實現激活函數 import matplotlib.pyplot as plt # 繪圖的工具 from torch.autograd import Variable # 生成偽數據 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1) y = x.pow(2) + 0.2 * torch.rand(x.size()) # 變為Variable x, y = Variable(x), Variable(y)
其中torch.linspace
是為了生成連續間斷的數據,第一個參數表示起點,第二個參數表示終點,第三個參數表示將這個區間分成平均幾份,即生成幾個數據。因為torch只能處理二維的數據,所以我們用torch.unsqueeze
給偽數據添加一個維度,dim表示添加在第幾維。torch.rand
返回的是[0,1)之間的均勻分布。
2.繪制數據圖像
在上述代碼后面加下面的代碼,然后運行可得偽數據的圖形化表示:
# 繪制數據圖像 plt.scatter(x.data.numpy(), y.data.numpy()) plt.show()
3.建立神經網絡
class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x): x = F.relu(self.hidden(x)) # activation function for hidden layer x = self.predict(x) # linear output return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network print(net) # net architecture
一般神經網絡的類都繼承自torch.nn.Module
,__init__()和forward()
兩個函數是自定義類的主要函數。在__init__()
中都要添加一句super(Net, self).__init__(),
這是固定的標準寫法,用于繼承父類的初始化函數。__init__()
中只是對神經網絡的模塊進行了聲明,真正的搭建是在forwad()
中實現。自定義類中的成員都通過self指針來進行訪問,所以參數列表中都包含了self。
如果想查看網絡結構,可以用print()
函數直接打印網絡。本文的網絡結構輸出如下:
Net ( (hidden): Linear (1 -> 10) (predict): Linear (10 -> 1) )
4.訓練網絡
# 訓練100次 for t in range(100): prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # 一定要是輸出在前,標簽在后 (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients
訓練網絡之前我們需要先定義優化器和損失函數。torch.optim
包中包括了各種優化器,這里我們選用最常見的SGD作為優化器。因為我們要對網絡的參數進行優化,所以我們要把網絡的參數net.parameters()
傳入優化器中,并設置學習率(一般小于1)。
由于這里是回歸任務,我們選擇torch.nn.MSELoss()
作為損失函數。
由于優化器是基于梯度來優化參數的,并且梯度會保存在其中。所以在每次優化前要通過optimizer.zero_grad()
把梯度置零,然后再后向傳播及更新。
5.可視化訓練過程
plt.ion() # something about plotting for t in range(100): ... if t % 5 == 0: # plot and show learning process plt.cla() plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'}) plt.pause(0.1) plt.ioff() plt.show()
6.運行結果
以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。