Last active
September 27, 2021 12:49
-
-
Save Flunzmas/0d35f67a3f5e73bdb952e1960b4b2388 to your computer and use it in GitHub Desktop.
PyG: Access individual graphs from a Batch object not created through Batch.from_data_list()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch_geometric.data import Data as GraphData

# ... load training data
train_data = None

# uses the following DataLoader: https://gist.github.com/Flunzmas/5a5c8c8fd553609359704be3174db793
# NOTE(review): `DataLoader` is that custom class and is not imported in this snippet;
# bring it into scope before running.
data_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, drop_last=True)

# Every item yielded by the loader is itself iterated as one batch object per timestep.
# The outer counter is named `batch_num` so that it is not clobbered by the
# node-to-graph assignment vector `batch_idx` extracted below.
for batch_num, data in enumerate(data_loader):
    for t, batch_at_timestep in enumerate(data):

        # Collect the batch's (key, value) pairs into a plain dict and split off the
        # "batch" entry, which is used below to select the rows belonging to each graph.
        batch_data = {key: value for key, value in batch_at_timestep}
        batch_idx = batch_data.pop("batch")
        unique_idx = batch_idx.unique()
        B, BV = len(unique_idx), len(batch_idx)  # batch_size, batch_size * |V|

        # TODO assuming constant number of nodes in batch, and prob. even constant number of edges!
        # Slice components of Batch object to get the individual data for a single graph.

        # Edge tensors are not indexed per node, so they are first reshaped to have a
        # leading dimension of size BV; the boolean node masks can then slice them too.
        # The reshapes raise if the number of edge entries is not divisible by BV.
        # NOTE(review): presumably edge_index arrives as (2, num_edges) -- confirm.
        batch_data["edge_index"] = batch_data["edge_index"].reshape(2, BV, -1).transpose(0, 1)
        batch_data["edge_attr"] = batch_data["edge_attr"].reshape(BV, -1)

        # One boolean mask per graph, built once and reused for every key.
        # Every remaining value must therefore be indexable by a length-BV mask.
        graph_masks = [batch_idx == i for i in unique_idx]
        batch_data = {k: [v[mask] for mask in graph_masks] for k, v in batch_data.items()}  # slice

        # Undo the temporary reshapes: edge_index back to (2, -1), edge_attr flattened to 1-D.
        batch_data["edge_index"] = [item.transpose(0, 1).reshape(2, -1) for item in batch_data["edge_index"]]
        batch_data["edge_attr"] = [item.reshape(-1) for item in batch_data["edge_attr"]]

        # Construct this timestep's individual graphs from sliced batch data
        graphs_t, node_start_idx = [], 0
        for b in range(B):
            # revert batch aggregation in edge_index by subtracting current start node idx
            # NOTE(review): this assumes each graph's nodes are stored contiguously and in
            # the same order as `unique_idx` -- confirm against the loader's collation.
            graph_t_b = GraphData(
                x=batch_data["x"][b],
                edge_index=batch_data["edge_index"][b] - node_start_idx,
                edge_attr=batch_data["edge_attr"][b],
                y=batch_data["y"][b],
            )
            graphs_t.append(graph_t_b)
            node_start_idx += batch_data["x"][b].shape[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment