您好,登錄后才能下訂單哦!
本篇內容主要講解“怎么使用Pytorch Geometric實現GraphSAGE模型”,感興趣的朋友不妨來看看。本文介紹的方法操作簡單快捷,實用性強。下面就讓小編來帶大家學習“怎么使用Pytorch Geometric實現GraphSAGE模型”吧!
在使用GraphSAGE對節點進行嵌入學習之前,我們需要先將原始數據轉換為圖結構,并將其存儲為Pytorch Tensor格式。例如,我們可以使用networkx庫來構建一個簡單的圖:
import networkx as nx G = nx.karate_club_graph()
然后,我們可以使用Pytorch Geometric庫將NetworkX圖轉換為Pytorch Tensor格式。首先,我們需要安裝Pytorch Geometric并導入所需的類:
!pip install torch-geometric from torch_geometric.datasets import Planetoid from torch_geometric.transforms import NormalizeFeatures from torch_geometric.utils.convert import from_networkx
接著,我們可以使用from_networkx
函數將NetworkX圖轉換為Pytorch Tensor格式:
data = from_networkx(G)
此時,data
對象包含了關于節點、邊及其屬性的信息,例如:
data.edge_index: 2x(#edges)的長整型張量,表示邊的起點和終點
data.x
: n×dn \times dn×d 的浮點型張量,表示每個節點的特征向量(其中nnn是節點數量,ddd是特征維度)
注意,此時的data
對象并未包含鄰居信息。接下來,我們將介紹如何使用Sampler方法采樣節點鄰居。
GraphSAGE使用Sampler方法來聚合鄰居信息。在Pytorch Geometric中,可以使用Various Sampling方法來實現Sampler。例如,使用ClusterData方法將圖分成多個子圖,然后對每個子圖進行采樣操作。
以下是ClusterData
的使用示例:
from torch_geometric.utils import degree, to_undirected from torch_geometric.transforms import ClusterData # Convert the graph to an undirected graph, so we can aggregate neighbors in both directions. G = to_undirected(G) # Compute the degree of each node. deg = degree(data.edge_index[0], num_nodes=data.num_nodes) # Use METIS algorithm to partition the graph into multiple subgraphs. cluster_data = ClusterData(data, num_parts=2, recursive=False, transform=NormalizeFeatures(), degree=deg)
這里我們將原始圖分成兩個子圖,并對每個子圖進行規范化特征轉換。注意,在使用ClusterData方法之前,需要將原始圖轉換為無向圖。
另一個常用的Sampler方法是在隨機游動時對鄰居進行采樣,這種方法被稱為隨機游走采樣(Random Walk Sampling)。以下是隨機游走采樣的示例代碼:
from torch_geometric.utils import random_walk # Perform random walk sampling to obtain node neighbor samples. walk_length = 20 # The length of random walk trail. num_steps = 4 # The number of nodes to sample from each step. data.batch = None data.edge_index = to_undirected(data.edge_index) # Use undirected edge for random walk. rw_data = random_walk(data.edge_index, walk_length=walk_length, num_steps=num_steps)
這里我們將使用一個長度為20、每個步驟采樣4個鄰居的隨機游走方法。注意,在使用隨機游走方法進行采樣之前,需要使用無向邊。
GraphSAGE模型包含3個部分:1)圖卷積層;2)聚合器(Aggregator);3)輸出層。我們將在本節中介紹如何使用Pytorch實現這些組件。
首先,讓我們定義一個圖卷積層。圖卷積層的輸入是節點特征矩陣、鄰接矩陣和聚合器,輸出是新的節點特征矩陣。以下是圖卷積層的代碼實現:
import torch.nn.functional as F from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn import global_mean_pool class GraphSageConv(MessagePassing): def __init__(self, in_channels, out_channels, aggr='mean'): super(GraphSageConv, self).__init__(aggr=aggr) self.lin = nn.Linear(in_channels, out_channels) def forward(self, x, edge_index): return self.propagate(edge_index, x=x) def message(self, x_j): return x_j def update(self, aggr_out, x): return F.relu(self.lin(torch.cat([x, aggr_out], dim=1)))
這里我們繼承了MessagePassing
類,并在__init__
函數中定義了一個全連接層,用于將輸入特征矩陣x
從 dind_{in}din 維映射到 doutd_{out}dout 維。在forward
函數中,我們使用propagate
方法來實現消息傳遞操作;在message
函數中,我們僅向下游節點發送原始特征數據;在update
函數中,我們首先對聚合結果進行ReLU非線性變換,然后再通過全連接層進行節點特征的更新。
接下來,讓我們定義一個聚合器。聚合器的輸入是采樣得到的鄰居特征矩陣,輸出是新的節點嵌入向量。以下是聚合器的代碼實現:
class MeanAggregator(nn.Module): def __init__(self, input_dim, output_dim): super(MeanAggregator, self).__init__() self.input_dim = input_dim self.output_dim = output_dim self.lin = nn.Linear(input_dim, output_dim) def forward(self, neigh_mean): out = F.relu(self.lin(neigh_mean)) return out
這里我們定義了一個簡單的均值聚合器,其將鄰居特征矩陣中每列的均值作為節點嵌入向量,并使用全連接層進行維度變換。
最后,讓我們定義整個GraphSage模型。GraphSage模型包含2個圖卷積層和1個輸出層。以下是模型的代碼實現:
class GraphSAGE(nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2): super(GraphSAGE, self).__init__() self.conv1 = GraphSageConv(in_channels, hidden_channels) self.aggreg1 = MeanAggregator(hidden_channels, hidden_channels) self.conv2 = GraphSageConv(hidden_channels, out_channels) def forward(self, x, edge_index): x = self.conv1(x, edge_index) x = global_mean_pool(x, edge_index) # Compute global mean over nodes. x = self.aggreg1(x) x = self.conv2(x, edge_index) return x
這里我們定義了一個包含2層GraphSAGE Conv層的神經網絡。在最后一層GraphSAGE Conv層之后,我們使用global_mean_pool
函數來計算節點嵌入的全局平均值。注意,在本示例中,我們僅保留了一個輸出節點,因此輸出矩陣的大小為1。如果需要輸出多個節點,則需要設置global_mean_pool
函數中的參數。
在定義好模型后,我們可以使用Pytorch進行模型訓練和測試。首先,讓我們定義一個損失函數和優化器:
criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
這里我們使用交叉熵作為損失函數,并使用Adam優化器來更新模型參數。
接著,我們可以開始訓練模型。以下是訓練過程的代碼實現:
num_epochs = 100 for epoch in range(num_epochs): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) loss = criterion(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() print('Epoch {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))
這里我們遍歷所有數據樣本,計算預測結果和真實標簽之間的交叉熵損失,并使用反向傳播來更新權重。我們在每個epoch結束后打印出當前損失值。
最后,我們可以對模型進行測試。以下是測試過程的代碼實現:
model.eval() with torch.no_grad(): pred = model(data.x, data.edge_index) pred = pred.argmax(dim=1) acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item() print('Test accuracy: {:.4f}'.format(acc))
這里我們使用測試集來計算模型的準確率。注意,在執行model.eval()
后,我們需要使用torch.no_grad()
包裝代碼塊,以禁止梯度計算。
到此,相信大家對“怎么使用Pytorch Geometric實現GraphSAGE模型”有了更深的了解,不妨來實際操作一番吧!這里是億速云網站,更多相關內容可以進入相關頻道進行查詢,關注我們,繼續學習!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。