728x90

 이번 글에서는 PyG를 이용해서 GNN 모델링을 하려면 어떤 흐름으로 모델링을 진행해야 하는지에 대해 간략하게 알아볼 것이다. PyG의 설치 방법은 다음 글을 참고하면 된다.

* 이 글은 pyG 2.1.0 버전을 기준으로 작성되었습니다. 이 글은 파이썬에 대한 기본 지식 및 GNN의 기본 용어에 대한 이해가 있다는 가정 하에 진행됩니다.

Library / module import

import os
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx

Data import

 우선, 데이터를 불러온다. 이 튜토리얼에서는 GNN의 가장 기본적인 예시 데이터 중 하나인 Karate club 데이터를 활용한다. 이 데이터는 34명의 Karate 클럽 멤버들 각각을 하나의 노드, 이들이 사적으로 교류가 있었다면 이들 사이에 엣지로 표현한 그래프 데이터로, 각 노드는 네 개의 커뮤니티 중 어떤 커뮤니티에 속해있는지의 label을 가진다. 아래 코드를 통해 데이터를 불러온다.

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

 이 때, dataset의 길이는 그래프의 개수, num_features는 노드 feature의 개수, num_classess는 label의 개수를 의미한다. 위 코드를 실행했을 때 다음과 같은 결과가 나오면 성공이다. 즉, 해당 데이터셋에는 한 개의 그래프가 포함되어 있고, 이 그래프가 가지는 feature의 개수는 34개이며 label의 개수는 4인 것이다.

결과

이외에도 하나의 그래프는 노드의 개수, 엣지의 개수 등 다양한 특성을 가진다. 이들을  살펴보기 위해서 다음 코드를 실행한다.

data = dataset[0]  # Get the first graph object.

print(data)
print('==============================================================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')

다음과 같은 결과가 나오면 된다. 즉, 이 그래프는 34개의 노드와 이 노드 사이에 156개의 엣지를 가지는 것이다.

결과

그렇다면 이 엣지들이 어떤 노드를 연결하고 있는지를 알아내기 위해서는 어떻게 해야할까? 이는 다음과 같이 edge_index를 살펴봄으로서 알 수 있다.

edge_index = data.edge_index
print(edge_index.t())

이를 실행하면 다음과 같은 결과가 도출된다.

엣지 연결 결과

 이 의미는 노드 0과 1, 노드 0과 2, 노드 0과 3, ..., 노드 0과 21이 연결되어 있다는 의미이다. 위 실행 결과에서 [0,20]은 포함되어 있지 않기 때문에, 노드 0과 노드 20은 연결되어 있지 않다. 이를 그림으로 살펴보기 위해서 우리는 다음과 같은 함수를 정의하고, 실행한다.

def visualize_graph(G, color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()
    
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

이를 통해 우리는 다음과 같이 그래프가 어떻게 연결되어 있는지를 한눈에 파악할 수 있다.

도출된 그래프

Modeling

 위 단계에서 우리는 PyG에서 그래프 데이터를 어떻게 import할 수 있는지와 더불어 PyG의 그래프 데이터의 형태에 대해 간략하게 살펴보았다. 이제 우리는 각 노드가 어떤 커뮤니티에 속해있는지의 노드 label을 GNN을 통해 예측해야 한다. 이를 위해서 우리는 GCN 모델을 사용하고, 그 코드는 아래와 같다.

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # Final GNN embedding space.
        
        # Apply a final (linear) classifier.
        out = self.classifier(h)

        return out, h

model = GCN()
print(model)

 위 코드를 실행하면 다음과 같은 결과가 도출된다.

모델링 결과

위 모델은 세 개의 GCNConv layer를 통과하여 34->4->4->2라는 단계를 거쳐 노드의 feature를 2차원으로 매핑하는 모델로, 각 layer는 tanh 함수를 통과하여 output을 도출한다. 마지막으로, 이 embedding이 MLP classifier를 통과하여 커뮤니티 label을 예측한다.

 이 모델을 우리의 데이터셋에 적용해서 노드의 label을 예측하기 위해서는 다음과 같은 코드를 실행한다.

model = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    out, h = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        print(loss)
        time.sleep(0.3)

이런 학습 과정을 거쳐서 점점 감소되는 loss가 도출되면 성공이다. 

 

 이번 글에서는 PyG를 어떻게 사용하면 되는지에 대해 간략하게 살펴보았다. 다음 PyG 포스팅부터는 PyG의 각 기능들에 대해 좀 더 상세하게 알아보겠다.

 

References

https://colab.research.google.com/drive/1h3-vJGRVloF5zStxL5I0rSy4ZUPNsjy8?usp=sharing#scrollTo=etxOsz8QIbMO 

 

300x250
  • 네이버 블러그 공유하기
  • 네이버 밴드에 공유하기
  • 페이스북 공유하기
  • 카카오스토리 공유하기