# Created September 3, 2020 18:37
# Save sparticlesteve/62854712aed7a7e46b70efaec0c64e4f to your computer and use it in GitHub Desktop.
# Modified graph convolutional LSTM (GConvLSTM) example showing how to train on a sequence of graphs.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import torch | |
import random | |
import numpy as np | |
import networkx as nx | |
import torch.nn.functional as F | |
from torch_geometric_temporal.nn.recurrent import GConvLSTM | |
def create_mock_data(number_of_nodes, edge_per_node, in_channels, bidirectional=True):
    """
    Create a mock node-feature matrix and edge index for a random small-world graph.

    The graph is generated with ``nx.watts_strogatz_graph`` using a rewiring
    probability of 0.5.

    Args:
        number_of_nodes: Number of nodes in the graph.
        edge_per_node: Neighbour count ``k`` passed to ``nx.watts_strogatz_graph``.
        in_channels: Number of features per node.
        bidirectional: If True (the default), every undirected edge is stored in
            both directions, which is how PyTorch Geometric represents undirected
            graphs. Pass False to get each edge once, as listed by networkx.

    Returns:
        A tuple ``(X, edge_index)``:
        X: float32 tensor of shape (number_of_nodes, in_channels) with values drawn
            uniformly from [-1, 1).
        edge_index: int64 tensor of shape (2, num_edges).
    """
    graph = nx.watts_strogatz_graph(number_of_nodes, edge_per_node, 0.5)
    # reshape(-1, 2) keeps a well-formed (2, 0) edge_index when the graph has no
    # edges; otherwise an empty edge list would become a 1-D tensor of shape (0,).
    edges = np.asarray(list(graph.edges()), dtype=np.int64).reshape(-1, 2)
    edge_index = torch.as_tensor(edges).t().contiguous()
    if bidirectional:
        # networkx reports each undirected edge once; add the reversed copy so
        # message passing flows in both directions.
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    X = torch.FloatTensor(np.random.uniform(-1, 1, (number_of_nodes, in_channels)))
    return X, edge_index
def create_mock_edge_weight(edge_index):
    """
    Draw one random weight per edge.

    The edge count is taken from the second dimension of ``edge_index``. Weights
    are sampled uniformly from [0, 1) with NumPy's global RNG and returned as a
    1-D float32 tensor.
    """
    num_edges = edge_index.shape[1]
    samples = np.random.uniform(low=0, high=1, size=num_edges)
    return torch.tensor(samples, dtype=torch.float32)
def create_mock_target(number_of_nodes, number_of_classes):
    """
    Assign every node a random class label.

    Labels are drawn with ``random.randint`` from 0 to ``number_of_classes - 1``
    inclusive, one per node, and returned as a 1-D int64 tensor.
    """
    highest_label = number_of_classes - 1
    labels = [random.randint(0, highest_label) for _ in range(number_of_nodes)]
    return torch.tensor(labels, dtype=torch.long)
class RecurrentGCN(torch.nn.Module):
    """
    Node classifier that runs two stacked GConvLSTM layers over a sequence of graphs.

    Each element of the input sequence is an ``(x, edge_index, edge_weight)``
    tuple. Hidden and cell states are carried from one snapshot to the next, so
    every snapshot is presumably expected to describe the same nodes in the same
    order (TODO confirm for real data); the edges may differ between snapshots.

    Args:
        node_features: Number of input features per node.
        num_classes: Number of output classes.
        hidden_channels: Output sizes of the first and second GConvLSTM layers.
        K: Chebyshev filter size passed to both GConvLSTM layers.
        dropout: Dropout probability applied before the final linear layer.
    """

    def __init__(self, node_features, num_classes, hidden_channels=(32, 16), K=5,
                 dropout=0.5):
        super(RecurrentGCN, self).__init__()
        hidden_1, hidden_2 = hidden_channels
        # Documentation for GConvLSTM:
        # https://pytorch-geometric-temporal.readthedocs.io/en/latest/modules/root.html#torch_geometric_temporal.nn.recurrent.gconv_lstm.GConvLSTM
        self.recurrent_1 = GConvLSTM(node_features, hidden_1, K)
        self.recurrent_2 = GConvLSTM(hidden_1, hidden_2, K)
        self.linear = torch.nn.Linear(hidden_2, num_classes)
        self.dropout_p = dropout

    def forward(self, graphs):
        """
        Classify nodes from the second layer's hidden state after the last snapshot.

        Args:
            graphs: Iterable of ``(x, edge_index, edge_weight)`` tuples, one per
                time step, processed in order.

        Returns:
            Per-node log-probabilities (``log_softmax`` over dim 1).

        Raises:
            ValueError: If ``graphs`` yields no snapshots.
        """
        # None lets GConvLSTM initialise its hidden and cell states itself.
        h1 = c1 = h2 = c2 = None
        for x, edge_index, edge_weight in graphs:
            h1, c1 = self.recurrent_1(x, edge_index, edge_weight, H=h1, C=c1)
            # The first layer's hidden state is the second layer's input.
            h2, c2 = self.recurrent_2(h1, edge_index, edge_weight, H=h2, C=c2)
        if h2 is None:
            # Without this check an empty sequence fails later inside F.relu with
            # an opaque NoneType error.
            raise ValueError("graphs must contain at least one (x, edge_index, edge_weight) snapshot")
        x = F.relu(h2)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.linear(x)
        return F.log_softmax(x, dim=1)
# Hyperparameters for the mock training run.
node_features = 100
node_count = 1000
num_classes = 10
sequence_len = 4
edge_per_node = 15
epochs = 200
learning_rate = 0.01
weight_decay = 5e-4

print('Building model')
model = RecurrentGCN(node_features=node_features, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

model.train()
for epoch in range(epochs):
    optimizer.zero_grad()

    # Build a fresh sequence of mock graph snapshots for this step.
    graphs = []
    for _ in range(sequence_len):
        x, edge_index = create_mock_data(node_count, edge_per_node, node_features)
        edge_weight = create_mock_edge_weight(edge_index)
        graphs.append((x, edge_index, edge_weight))
    # The labels are random, so the loss is not expected to decrease meaningfully;
    # this loop exercises the data flow rather than learning anything.
    target = create_mock_target(node_count, num_classes)

    # Run the model over the whole sequence, then take one optimisation step.
    scores = model(graphs)
    loss = F.nll_loss(scores, target)
    loss.backward()
    optimizer.step()
    # Report the loss so training progress is visible.
    print('Step', epoch, 'loss', loss.item())
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.