在上一篇文章中,我們理解到GNN中的訊息傳遞機制,那在這一篇會透過 pytorch 實作訊息傳遞機制,還沒看過上篇的可以點以下連結:
【邁向圖神經網絡GNN】Part1: 圖數據的基本元素與應用
【邁向圖神經網絡GNN】Part2: 使用PyTorch構建圖形結構的概念與實作
【邁向圖神經網絡GNN】Part3: 圖神經網絡的核心-訊息傳遞機制
定義 message passing 的 class
定義一個 class 執行 message passing ,有兩個重要的元素:
- init : 在這裡, init 會定義 agg 使用的 function ,這裡選用 max
- forward : 當資料餵進去之後,會執行 propagate 函數,那這個函數會去呼叫 message 和 update
- 定義 message function:如同上篇的範例, 0.5 * 自己 + 2 *鄰居
- 定義 update function : 也是同上篇,1倍自己 + 0.5 倍 message
class self_designed_MessagePassingLayer(MessagePassing):
def __init__(self, aggr='max'):
super(self_designed_MessagePassingLayer, self).__init__(aggr)
self.aggr = aggr
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
def message(self, x_i, x_j):
return 0.5 * x_i + 2 * x_j
def update(self, aggr_out, x):
return x + 0.5 * aggr_out
實作 message passing class
# define message passing layer
self_designed_mp_layer = self_designed_MessagePassingLayer(aggr='max')
我們做一回合的 message passing
# Go through 1 message passing:
graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
graph.x = self_designed_mp_layer(graph.x, graph.edge_index)
print(f"After 1 mp layer, graph.x = \n{graph.x}\n")
輸入這個 class 的資料是:
- graph.x: node 的 feature
- graph.edge_index : 哪些節點相連
input 資料餵進去後,會去呼叫 forward ,再去執行 message 和 update
產出一次傳遞的結果:
After 1 mp layer, graph.x =
tensor([[12.5000, 8.0000],
[ 6.0000, 5.2500],
[12.2500, 7.7500],
[ 6.2500, 5.5000]])
以 node 0 為例,推導過程:
- from node 1 = 0.5*(6,4) + 2*(0,1) = (3,4)
- from node 2 = 0.5*(6,4) +2*(5,3) = (13,8)
以 max 取得 agg_out ,所以選擇 (13,8)
再做 node update
(6,4) + 0.5*(13,8) = (12.5, 8)
一回合的 message passing 傳遞一步距離的鄰居,十回合的 message passing 傳遞十步距離的鄰居
我們做十回合的 message passing
# Go through 10 message passing:
graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
for i in range(10):
graph.x = self_designed_mp_layer(graph.x, graph.edge_index)
print(f"After 10 mp layer, graph.x = \n{graph.x}")
輸出的結果
After 10 mp layer, graph.x =
tensor([[1839130.8750, 1406359.0000],
[1839130.8750, 1406359.0000],
[2298917.0000, 1758024.6250],
[ 919569.5000, 703268.6875]])
在這裡會發現目前使用到的 message passing 機制,僅用到 node feature ,並未加上 edge 的 feature ,所以其實還有很多更進階的 message passing 的 define。
各種不同的 message passing 機制
在 torch 的官網上,還有很多不同的 GNN ,主要的差別在於 message passing 的機制,最經常拿來使用與比較的是 GCNcov 的算法,我們也來實作看看~
特殊的 message passing 機制,讓 nn 自定義 message 和 update function
上述的基本範例是採用自定義的 message passing 機制,那進階與改良版之一,則是讓 neural network 自己學最適的機制。
class NN_MessagePassingLayer(MessagePassing):
def __init__(self, input_dim, hidden_dim, output_dim, aggr='mean'):
super(NN_MessagePassingLayer, self).__init__()
self.aggr = aggr
self.messageNN = nn.Linear(input_dim * 2, hidden_dim)
self.updateNN = nn.Linear(input_dim + hidden_dim, output_dim)
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x, messageNN=self.messageNN, updateNN=self.updateNN)
def message(self, x_i, x_j, messageNN):
return messageNN(torch.cat((x_i, x_j), dim=-1))
def update(self, aggr_out, x, updateNN):
return updateNN(torch.cat((x, aggr_out), dim=-1))
1. __init__
方法
input 包含:
input_dim
:節點特徵的維度。hidden_dim
:隱藏層的維度。output_dim
:輸出特徵的維度。aggr
:聚合函數,default 為 'mean',可以根據需要改為 'max', 'add' 等。
在這個方法中,先 init ,並設定 agg ,再 create 兩個神經網路層 — message NN + update NN
messageNN
:一個線性轉換層,用於將連接的節點對的特徵(x_i
和x_j
)轉換成隱藏層表示。其輸入維度是兩個節點特徵維度的總和。updateNN
:另一個線性轉換層,用於更新節點特徵。輸入為原始節點特徵和聚合後的訊息特徵,輸出為新的節點特徵。
2. forward
方法
這是NN的前向傳播方法,負責調用 propagate
方法進行訊息的傳遞。它接收節點特徵 x
和邊索引 edge_index
。propagate
方法是 MessagePassing
中定義的,用於處理訊息的生成、聚合和更新。
3. message
方法
定義了如何生成訊息。它接收來自兩個相連節點的特徵(x_i
和 x_j
),並使用 messageNN
將它們拼接後轉換成一個訊息。這個訊息隨後將被聚合到相應的節點上。
4. update
方法
最後,update
方法定義了如何根據聚合的訊息更新每個節點的特徵。它接收聚合後的訊息 aggr_out
和原始節點特徵 x
,使用 updateNN
將它們拼接後進行轉換,生成最終的節點特徵。
小結
今天探討了圖神經網絡(GNN)中的message passing 機制,並通過PyTorch實現了具體的程式碼實作。我們先回顧了其核心概念,然後進一步實作如何在PyTorch中定義和實現一個自定義的message passing layer。通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
後半部主要說明不同的 message passing 機制,像是最經典的加入 nn 變體,對於 message passing 機制的改良有一點想像,那以上是本篇的內容,下一篇見~