在PyTorch中創建神經網絡模型通常需要定義一個繼承自torch.nn.Module
類的自定義類。下面是一個簡單的示例:
import torch
import torch.nn as nn
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(784, 128) # 定義一個全連接層
self.relu = nn.ReLU() # 定義一個激活函數
self.fc2 = nn.Linear(128, 10) # 定義另一個全連接層
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
在上面的示例中,我們定義了一個簡單的神經網絡模型SimpleNN
,包括兩個全連接層和一個ReLU激活函數。__init__
方法用于定義模型的結構,forward
方法用于定義模型的前向傳播過程。
要使用這個模型,可以實例化一個對象并傳入輸入數據進行前向傳播計算:
model = SimpleNN()
input_data = torch.randn(1, 784) # 創建一個輸入數據張量
output = model(input_data) # 進行前向傳播
print(output)
這樣就可以在PyTorch中創建一個簡單的神經網絡模型了。您可以根據自己的需求定義更復雜的模型結構和前向傳播過程。