【邁向圖神經網絡GNN】Part4: 實作圖神經網路訊息傳遞機制

Karen
8 min readJul 24, 2024

--

在上一篇文章中,我們理解到GNN中的訊息傳遞機制,那在這一篇會透過 pytorch 實作訊息傳遞機制,還沒看過上篇的可以點以下連結:

【邁向圖神經網絡GNN】Part1: 圖數據的基本元素與應用

【邁向圖神經網絡GNN】Part2: 使用PyTorch構建圖形結構的概念與實作

【邁向圖神經網絡GNN】Part3: 圖神經網絡的核心-訊息傳遞機制

Photo by Pedro Henrique Santos on Unsplash

定義 message passing 的 class

定義一個 class 執行 message passing ,有兩個重要的元素:

  1. init : 在這裡, init 會定義 agg 使用的 function ,這裡選用 max
  2. forward : 當資料餵進去之後,會執行 propagate 函數,那這個函數會去呼叫 message 和 update
  • 定義 message function:如同上篇的範例, 0.5 * 自己 + 2 *鄰居
  • 定義 update function : 也是同上篇,1倍自己 + 0.5 倍 message
class self_designed_MessagePassingLayer(MessagePassing):
def __init__(self, aggr='max'):
super(self_designed_MessagePassingLayer, self).__init__(aggr)
self.aggr = aggr

def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)

def message(self, x_i, x_j):
return 0.5 * x_i + 2 * x_j

def update(self, aggr_out, x):
return x + 0.5 * aggr_out

實作 message passing class

# define message passing layer
self_designed_mp_layer = self_designed_MessagePassingLayer(aggr='max')

我們做一回合的 message passing

# Go through 1 message passing:
graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
graph.x = self_designed_mp_layer(graph.x, graph.edge_index)
print(f"After 1 mp layer, graph.x = \n{graph.x}\n")

輸入這個 class 的資料是:

  • graph.x: node 的 feature
  • graph.edge_index : 哪些節點相連

input 資料餵進去後,會去呼叫 forward ,再去執行 message 和 update

產出一次傳遞的結果:

After 1 mp layer, graph.x =  
tensor([[12.5000, 8.0000],
[ 6.0000, 5.2500],
[12.2500, 7.7500],
[ 6.2500, 5.5000]])

以 node 0 為例,推導過程:

  • from node 1 = 0.5*(6,4) + 2*(0,1) = (3,4)
  • from node 2 = 0.5*(6,4) +2*(5,3) = (13,8)

以 max 取得 agg_out ,所以選擇 (13,8)

再做 node update

(6,4) + 0.5*(13,8) = (12.5, 8)

一回合的 message passing 傳遞一步距離的鄰居,十回合的 message passing 傳遞十步距離的鄰居

我們做十回合的 message passing

# Go through 10 message passing:
graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
for i in range(10):
graph.x = self_designed_mp_layer(graph.x, graph.edge_index)
print(f"After 10 mp layer, graph.x = \n{graph.x}")

輸出的結果

After 10 mp layer, graph.x =  
tensor([[1839130.8750, 1406359.0000],
[1839130.8750, 1406359.0000],
[2298917.0000, 1758024.6250],
[ 919569.5000, 703268.6875]])

在這裡會發現目前使用到的 message passing 機制,僅用到 node feature ,並未加上 edge 的 feature ,所以其實還有很多更進階的 message passing 的 define。

各種不同的 message passing 機制

https://pytorch-geometric.readthedocs.io/en/1.7.2/modules/nn.html

在 torch 的官網上,還有很多不同的 GNN ,主要的差別在於 message passing 的機制,最經常拿來使用與比較的是 GCNcov 的算法,我們也來實作看看~

特殊的 message passing 機制,讓 nn 自定義 message 和 update function

上述的基本範例是採用自定義的 message passing 機制,那進階與改良版之一,則是讓 neural network 自己學最適的機制。

class NN_MessagePassingLayer(MessagePassing):
def __init__(self, input_dim, hidden_dim, output_dim, aggr='mean'):
super(NN_MessagePassingLayer, self).__init__()
self.aggr = aggr

self.messageNN = nn.Linear(input_dim * 2, hidden_dim)
self.updateNN = nn.Linear(input_dim + hidden_dim, output_dim)

def forward(self, x, edge_index):
return self.propagate(edge_index, x=x, messageNN=self.messageNN, updateNN=self.updateNN)

def message(self, x_i, x_j, messageNN):
return messageNN(torch.cat((x_i, x_j), dim=-1))

def update(self, aggr_out, x, updateNN):
return updateNN(torch.cat((x, aggr_out), dim=-1))

1. __init__ 方法

input 包含:

  • input_dim:節點特徵的維度。
  • hidden_dim:隱藏層的維度。
  • output_dim:輸出特徵的維度。
  • aggr:聚合函數,default 為 'mean',可以根據需要改為 'max', 'add' 等。

在這個方法中,先 init ,並設定 agg ,再 create 兩個神經網路層 — message NN + update NN

  • messageNN:一個線性轉換層,用於將連接的節點對的特徵(x_ix_j)轉換成隱藏層表示。其輸入維度是兩個節點特徵維度的總和。
  • updateNN:另一個線性轉換層,用於更新節點特徵。輸入為原始節點特徵和聚合後的訊息特徵,輸出為新的節點特徵。

2. forward 方法

這是NN的前向傳播方法,負責調用 propagate 方法進行訊息的傳遞。它接收節點特徵 x 和邊索引 edge_indexpropagate 方法是 MessagePassing 中定義的,用於處理訊息的生成、聚合和更新。

3. message 方法

定義了如何生成訊息。它接收來自兩個相連節點的特徵(x_ix_j),並使用 messageNN 將它們拼接後轉換成一個訊息。這個訊息隨後將被聚合到相應的節點上。

4. update 方法

最後,update 方法定義了如何根據聚合的訊息更新每個節點的特徵。它接收聚合後的訊息 aggr_out 和原始節點特徵 x,使用 updateNN 將它們拼接後進行轉換,生成最終的節點特徵。

小結

今天探討了圖神經網絡(GNN)中的message passing 機制,並通過PyTorch實現了具體的程式碼實作。我們先回顧了其核心概念,然後進一步實作如何在PyTorch中定義和實現一個自定義的message passing layer。通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。

後半部主要說明不同的 message passing 機制,像是最經典的加入 nn 變體,對於 message passing 機制的改良有一點想像,那以上是本篇的內容,下一篇見~

--

--

Karen

分享一些自學的程式筆記,紀錄自己的成長歷程軌跡