RGCN implementation from scratch. Untested in gist form. Let me know if you need this for something.
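The model takes no node features; its only input is the graph structure, passed to the constructor as an edges dictionary that maps each relation index to a pair (from, to) of equal-length lists of node indices, so that matching positions in the two lists form the edges of that relation. A usage sketch follows the code below.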
import torch, os, sys
from torch import nn
import torch.nn.functional as F
import torch.distributions as ds

from math import sqrt, ceil

import layers, util

import torch as T
class RGCNClassic(nn.Module):
    """
    Classic RGCN
    """

    def __init__(self, edges, n, numcls, emb=16, bases=None, softmax=False):
        super().__init__()

        self.emb = emb
        self.bases = bases
        self.numcls = numcls
        self.softmax = softmax

        # horizontally and vertically stacked versions of the adjacency graph
        hor_ind, hor_size = adj(edges, n, vertical=False)  # size (n, r*n)
        ver_ind, ver_size = adj(edges, n, vertical=True)   # size (r*n, n)

        _, rn = hor_size
        r = rn // n

        t = len(edges[0][0])  # number of edges for relation 0 (not used below)

        vals = torch.ones(ver_ind.size(0), dtype=torch.float)
        vals = vals / util.sum_sparse(ver_ind, vals, ver_size)  # row-normalize
        # -- the values are the same for the horizontally and the vertically stacked
        #    adjacency matrices, so we can normalize them by the vertically stacked one
        #    and reuse them for the horizontal one

        hor_graph = torch.sparse.FloatTensor(indices=hor_ind.t(), values=vals, size=hor_size)
        self.register_buffer('hor_graph', hor_graph)

        ver_graph = torch.sparse.FloatTensor(indices=ver_ind.t(), values=vals, size=ver_size)
        self.register_buffer('ver_graph', ver_graph)

        # layer 1 weights
        if bases is None:
            self.weights1 = nn.Parameter(torch.FloatTensor(r, n, emb))
            nn.init.xavier_uniform_(self.weights1, gain=nn.init.calculate_gain('relu'))
            self.bases1 = None
        else:
            self.comps1 = nn.Parameter(torch.FloatTensor(r, bases))
            nn.init.xavier_uniform_(self.comps1, gain=nn.init.calculate_gain('relu'))
            self.bases1 = nn.Parameter(torch.FloatTensor(bases, n, emb))
            nn.init.xavier_uniform_(self.bases1, gain=nn.init.calculate_gain('relu'))

        # layer 2 weights
        if bases is None:
            self.weights2 = nn.Parameter(torch.FloatTensor(r, emb, numcls))
            nn.init.xavier_uniform_(self.weights2, gain=nn.init.calculate_gain('relu'))
            self.bases2 = None
        else:
            self.comps2 = nn.Parameter(torch.FloatTensor(r, bases))
            nn.init.xavier_uniform_(self.comps2, gain=nn.init.calculate_gain('relu'))
            self.bases2 = nn.Parameter(torch.FloatTensor(bases, emb, numcls))
            nn.init.xavier_uniform_(self.bases2, gain=nn.init.calculate_gain('relu'))

        self.bias1 = nn.Parameter(torch.FloatTensor(emb).zero_())
        self.bias2 = nn.Parameter(torch.FloatTensor(numcls).zero_())
    def forward(self):

        ## Layer 1

        n, rn = self.hor_graph.size()
        r = rn // n
        e = self.emb
        b, c = self.bases, self.numcls

        if self.bases1 is not None:
            # weights = torch.einsum('rb, bij -> rij', self.comps1, self.bases1)
            weights = torch.mm(self.comps1, self.bases1.view(b, n*e)).view(r, n, e)
        else:
            weights = self.weights1

        assert weights.size() == (r, n, e)

        # Apply weights and sum over relations
        h = torch.mm(self.hor_graph, weights.view(r*n, e))
        assert h.size() == (n, e)

        h = F.relu(h + self.bias1)

        ## Layer 2

        # Multiply adjacencies by hidden
        h = torch.mm(self.ver_graph, h)  # sparse mm
        h = h.view(r, n, e)  # new dim for the relations

        if self.bases2 is not None:
            # weights = torch.einsum('rb, bij -> rij', self.comps2, self.bases2)
            weights = torch.mm(self.comps2, self.bases2.view(b, e * c)).view(r, e, c)
        else:
            weights = self.weights2

        # Apply weights, sum over relations
        # h = torch.einsum('rhc, rnh -> nc', weights, h)
        h = torch.bmm(h, weights).sum(dim=0)

        assert h.size() == (n, c)

        if self.softmax:
            return F.softmax(h + self.bias2, dim=1)

        return h + self.bias2  # -- softmax is applied in the loss
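
# -- Note on the stacked adjacencies used above.
#    A classic RGCN layer computes  H' = act( sum_r  A_r H W_r + b ),  where A_r is the
#    (row-normalized) adjacency matrix of relation r.
#    Layer 1 has no input features (H is implicitly the identity, i.e. one-hot nodes), so
#    sum_r A_r W_r can be computed with a single sparse mm: hor_graph = [A_1 | ... | A_R]
#    of size (n, R*n), times the vertically stacked weights [W_1; ...; W_R] of size (R*n, emb).
#    Layer 2 uses the vertical stack instead: ver_graph @ H = [A_1 H; ...; A_R H] of size
#    (R*n, emb), which is reshaped to (R, n, emb) and multiplied per relation with the
#    (R, emb, numcls) weights via torch.bmm, then summed over the relation dimension.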
def adj(edges, num_nodes, cuda=False, vertical=True):
    """
    Computes the indices and size of a sparse adjacency matrix for the given graph, with the
    adjacency matrices of all relations stacked vertically (size (r*n, n)) or horizontally
    (size (n, r*n)).

    :param edges: dictionary mapping each relation index to a pair (from, to) of
        equal-length lists of node indices
    :param num_nodes: total number of nodes n
    :param cuda: if true, place the index tensor on the GPU
    :param vertical: stack the relation blocks vertically if true, horizontally otherwise
    :return: (indices, size) pair from which the sparse tensor can be constructed
    """
    ST = torch.cuda.sparse.FloatTensor if cuda else torch.sparse.FloatTensor

    r, n = len(edges.keys()), num_nodes
    size = (r*n, n) if vertical else (n, r*n)

    from_indices = []
    upto_indices = []

    for rel, (fr, to) in edges.items():

        offset = rel * n

        if vertical:
            fr = [offset + f for f in fr]
        else:
            to = [offset + t for t in to]

        from_indices.extend(fr)
        upto_indices.extend(to)

    indices = torch.tensor([from_indices, upto_indices], dtype=torch.long,
                           device='cuda' if cuda else 'cpu')

    assert indices.size(1) == sum([len(ed[0]) for _, ed in edges.items()])
    assert indices[0, :].max() < size[0], f'{indices[0, :].max()}, {size}, {r}, {edges.keys()}'
    assert indices[1, :].max() < size[1], f'{indices[1, :].max()}, {size}, {r}, {edges.keys()}'

    return indices.t(), size
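
Below is a minimal usage sketch, not part of the gist itself: the toy graph, node count, labels and hyperparameters are invented for illustration, and it assumes a util module providing sum_sparse (the row-wise sum used in the constructor) is importable.

# Hypothetical usage sketch; the data below is invented for illustration.
edges = {
    0: ([0, 1, 2], [1, 2, 3]),   # relation 0: edges 0->1, 1->2, 2->3
    1: ([3, 0], [0, 2]),         # relation 1: edges 3->0, 0->2
}

model = RGCNClassic(edges, n=4, numcls=2, emb=16, bases=None)
opt = torch.optim.Adam(model.parameters(), lr=0.01)

train_idx = torch.tensor([0, 2])   # nodes with known labels (invented)
train_lbl = torch.tensor([0, 1])

for epoch in range(50):
    opt.zero_grad()
    out = model()                  # full-batch forward pass: (n, numcls) logits
    loss = F.cross_entropy(out[train_idx], train_lbl)
    loss.backward()
    opt.step()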