Skip to content

Instantly share code, notes, and snippets.

@CraftsMan-Labs
Created February 28, 2025 16:00
Show Gist options
  • Save CraftsMan-Labs/a1bb5787d35cac81eeab7ee12dcae632 to your computer and use it in GitHub Desktop.
Save CraftsMan-Labs/a1bb5787d35cac81eeab7ee12dcae632 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
from rdkit import Chem
import numpy as np
# 1. Molecular graph construction
def molecule_to_graph(smiles):
    """Convert a SMILES string to a PyTorch Geometric graph.

    Each atom becomes a node with six float features, in this order:
    atomic number, degree, formal charge, hybridization (RDKit enum as int),
    aromatic flag (0/1), and atomic mass. Each bond becomes two directed
    edges (i -> j and j -> i).

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        A ``torch_geometric.data.Data`` with ``x`` of shape ``[num_atoms, 6]``
        and ``edge_index`` of shape ``[2, 2 * num_bonds]`` (``[2, 0]`` when the
        molecule has no bonds).

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` or the molecule has no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    # RDKit reports a parse failure by returning None instead of raising;
    # without this check the failure shows up later as an AttributeError.
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    # An atom-less molecule (e.g. "") would yield a 1-D empty feature tensor
    # and contribute no row to graph pooling, so reject it explicitly.
    if mol.GetNumAtoms() == 0:
        raise ValueError(f"SMILES string has no atoms: {smiles!r}")

    atom_features = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            atom.GetIsAromatic(),
            atom.GetMass(),
        ]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(atom_features, dtype=torch.float)

    # Store every bond in both directions so messages flow both ways.
    edge_pairs = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_pairs.extend(([i, j], [j, i]))

    if edge_pairs:
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
    else:
        # Bond-free molecules (single atoms, disconnected ions) would otherwise
        # give a 1-D empty tensor; edge_index must be shaped [2, num_edges].
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)
# 2. Graph Neural Network for drug molecules
class DrugGNN(nn.Module):
    """Three-layer GCN that embeds a batch of molecular graphs.

    Node features pass through ``input_dim -> hidden_dim -> hidden_dim ->
    output_dim`` graph convolutions (ReLU after the first two only), and the
    resulting node embeddings are mean-pooled per graph.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        """Return one ``output_dim`` embedding per graph in ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``batch`` attributes,
        e.g. a ``torch_geometric.data.Batch``.
        """
        edge_index = data.edge_index
        h = data.x
        # Hidden convolutions are followed by ReLU; the output one is linear.
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, edge_index))
        h = self.conv3(h, edge_index)
        # Average node embeddings within each graph -> [batch_size, output_dim].
        return global_mean_pool(h, data.batch)
# 3. Protein encoding using sequence information
class ProteinEncoder(nn.Module):
    """Encode integer-coded protein sequences with a bidirectional LSTM.

    Token ids are embedded and run through a bidirectional LSTM
    (``batch_first=True``). The final forward and backward hidden states are
    concatenated and linearly projected to ``output_dim``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions' final states are concatenated, hence 2 * hidden_dim.
        self.fc = nn.Linear(2 * hidden_dim, output_dim)

    def forward(self, x):
        """Map a ``[batch, seq_len]`` tensor of token ids to ``[batch, output_dim]``."""
        _, (h_n, _) = self.lstm(self.embedding(x))
        # h_n is [num_directions * num_layers, batch, hidden_dim]; its last two
        # rows are the top layer's forward and backward final states.
        final_states = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(final_states)
# 4. Drug-Target Interaction model
class GraphDTA(nn.Module):
    """Predict a drug-target binding affinity from a molecule graph and a protein.

    The drug graph is embedded by ``DrugGNN`` and the protein sequence by
    ``ProteinEncoder``. The two embeddings are concatenated and passed through a
    three-layer MLP that outputs one scalar per (drug, protein) pair.

    Args:
        drug_input_dim, drug_hidden_dim, drug_output_dim: ``DrugGNN`` sizes.
        protein_vocab_size, protein_embedding_dim, protein_hidden_dim,
        protein_output_dim: ``ProteinEncoder`` sizes.
        fc1_dim: width of the first prediction layer (default 1024).
        fc2_dim: width of the second prediction layer (default 512).
        dropout: dropout probability applied after the first prediction layer
            while the module is in training mode (default 0.1).
    """

    def __init__(self, drug_input_dim, drug_hidden_dim, drug_output_dim,
                 protein_vocab_size, protein_embedding_dim, protein_hidden_dim,
                 protein_output_dim, fc1_dim=1024, fc2_dim=512, dropout=0.1):
        super().__init__()
        # Drug branch: graph convolutions + per-graph pooling.
        self.drug_gnn = DrugGNN(drug_input_dim, drug_hidden_dim, drug_output_dim)
        # Protein branch: embedding + bidirectional LSTM + projection.
        self.protein_encoder = ProteinEncoder(protein_vocab_size, protein_embedding_dim,
                                              protein_hidden_dim, protein_output_dim)
        # Prediction head over the concatenated drug/protein embeddings.
        combined_dim = drug_output_dim + protein_output_dim
        self.fc1 = nn.Linear(combined_dim, fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, 1)
        self.dropout_p = dropout

    def forward(self, drug_data, protein_data):
        """Return predicted affinities of shape ``[batch_size, 1]``.

        ``drug_data`` is the batched graph consumed by ``DrugGNN`` and
        ``protein_data`` the token-id tensor consumed by ``ProteinEncoder``.
        Both must describe the same number of pairs so their embeddings can be
        concatenated row-wise.
        """
        drug_embedding = self.drug_gnn(drug_data)
        protein_embedding = self.protein_encoder(protein_data)
        combined = torch.cat((drug_embedding, protein_embedding), dim=1)

        x = F.relu(self.fc1(combined))
        # Functional dropout keyed to self.training, so model.eval() disables it.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Example usage
def main():
    """Run one forward pass of an untrained GraphDTA on aspirin and a dummy protein.

    The model has freshly initialized weights, so the printed number only shows
    that the pipeline runs end to end; it is not a meaningful affinity.
    """
    from torch_geometric.data import Batch

    # Example drug SMILES string (aspirin), wrapped in a one-graph batch.
    smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
    drug_graph = molecule_to_graph(smiles)
    drug_batch = Batch.from_data_list([drug_graph])

    # Dummy protein sequence, integer-encoded with ids starting at 1.
    protein_seq = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=torch.long)

    # Read the atom-feature width from the graph rather than hard-coding it,
    # so this stays in sync with molecule_to_graph.
    drug_input_dim = drug_graph.num_node_features
    drug_hidden_dim = 128
    drug_output_dim = 128
    # Ids 1..20 for the 20 standard amino acids need 21 embedding rows
    # (index 0 stays free, e.g. for padding); 20 rows would reject id 20.
    protein_vocab_size = 21
    protein_embedding_dim = 128
    protein_hidden_dim = 128
    protein_output_dim = 128

    model = GraphDTA(drug_input_dim, drug_hidden_dim, drug_output_dim,
                     protein_vocab_size, protein_embedding_dim, protein_hidden_dim,
                     protein_output_dim)

    # eval() disables dropout so the prediction is deterministic for these
    # weights; no_grad() avoids building an autograd graph for inference.
    model.eval()
    with torch.no_grad():
        binding_affinity = model(drug_batch, protein_seq)
    print(f"Predicted binding affinity: {binding_affinity.item()}")
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment