Message passing neural network (MPNN)는 graph neural network (GNN)의 가장 기본이 되는 프레임워크로, 노드의 이웃들의 정보를 이용해서 해당 노드의 상태를 업데이트하는 형태를 가지는 모든 neural network들을 말한다. 이번 글에서는 MPNN이 무엇인지에 대해 알아보겠다.
구조
MPNN은 기본적으로 정보를 aggregate하고 update하는 message passing phase와 이를 활용해서 결과를 도출하는 readout phase로 이루어진다.
message passing phase
message passing phase는 정보를 aggregate하는 역할을 하는 message function과 hidden state를 update하는 역할을 하는 update function으로 이루어진다.
message passing
Message function은 노드 v에 대한 정보를 얻기 위해 정보들을 aggregate하는 역할을 한다. 노드 v에 대한 message function은 다음과 같은 형태를 가진다.
위 식에서 $h_i$는 노드 i에 대한 hidden state, $N(v)$는 노드 v의 neighborhood, $e_{vw}$는 edge vw의 feature vector, $M_t$는 이 모든 것을 aggregate하는 message function이다. 즉, message function은 우리가 알아보고 싶은 노드의 현재 상태, 그 노드의 이웃들의 현재 상태, 그리고 그 노드와 노드의 이웃들을 연결하는 엣지들의 정보를 aggregate하여 우리가 알아보고 싶은 노드의 다음 message를 표현하는 것이다. 예를 들어, 다음과 같은 그래프가 있다고 하자.
이 그래프의 노란색 노드 A의 정보를 얻고 싶다면 어떻게 해야할까? 우선, 이웃한 노드인 B, C, D의 정보를 aggregate해서 message를 얻을 수 있을 것이다. 그렇다면 이 이웃한 노드 B, C, D의 정보는 어떻게 얻을 수 있을까? 또 그것 각각의 이웃 A, C / A, B, E, F / A의 정보를 aggregate해서 message를 얻을 수 있다. 이를 그림으로 표현하면 다음과 같다.
update function
Update function은 이렇게 얻어진 message를 활용해서 노드의 다음 hidden state를 update하는 역할을 한다. 노드 v에 대한 update function은 다음과 같은 형태를 가진다.
즉, 앞에서 얻었던 message function ($m_v^{t+1}$)과 함께 노드의 현재 hidden state ($h_v^t$)를 고려하여 노드의 다음 hidden state를 update하는 것이다.
readout phase
다음으로, readout phase에서는 이렇게 우리가 얻은 hidden state를 활용해서 우리가 예측하기를 원하는 노드의 label, 그래프의 label 등을 도출한다. 이를 나타내는 식은 다음과 같다.
즉, 위의 message passing phase를 T번 반복해서, 각 노드 v들에 대한 T번째 hidden state를 얻고, 이 hidden state를 readout function에 넣어서, 우리가 원하는 label을 도출하는 것이다.
장점
이러한 MPNN 구조가 가지는 장점은 그래프의 structural한 정보와 feature 정보를 모두 얻을 수 있다는 것이다. 어떤 노드의 embedding을 얻음으로써 우리는 그 노드의 k-hop 이웃들의 정보를 모두 반영할 수 있고, 이는 노드가 어떤 노드들과 연결되어 있는지의 structural한 정보를 반영할 수 있다고 할 수 있다. 또한 k-hop 이웃들의 feature들 또한 aggregate하여 반영함으로써 feature 정보까지 반영할 수 있다.
예시
이러한 MPNN 프레임워크를 따르는 GNN 모델은 굉장히 많다. 그의 대표적인 예시로는 gated graph neural network (GGNN), graph convolutional network (GCN) 등이 있다. 이 각각에 대한 글은 언젠가 미래에..
이번 글에서는 GNN의 가장 기본이 되는 프레임워크인 message passing neural network가 무엇이고, 그 장점과 예시에는 어떤 것이 있는지에 대해 알아보았다.
블로그 햇수로 4년차인데 수식 쓰는 법을 이제야 알아서 처음으로 문장 안에 수식을 제대로 넣어보았다.. ㅎㅎ
References
1. Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., & Dahl, G. E. (2017, July). Neural message passing for quantum chemistry. In International conference on machine learning (pp. 1263-1272). PMLR.
2. Hamilton, W. L. Graph Representation Learning.
최근댓글