import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Batch, Data
from rdkit import Chem

# 1. Molecular graph construction
def molecule_to_graph(smiles):
    """Convert a SMILES string to a PyTorch Geometric graph."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles}")

    # Per-atom feature vectors
    atom_features = []
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetHybridization().real,
            atom.GetIsAromatic(),
            atom.GetMass(),
        ]
        atom_features.append(features)
    x = torch.tensor(atom_features, dtype=torch.float)

    # Edge indices; each bond is added in both directions so the graph
    # is treated as undirected
    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i])
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index)
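
# A minimal batching sketch. `smiles_to_batch` is a hypothetical helper, not
# part of the original gist; it assumes every SMILES string parses and shows
# how several molecules would be combined into one PyG Batch for a single
# forward pass through the GNN below.
def smiles_to_batch(smiles_list):
    """Convert a list of SMILES strings into one PyTorch Geometric Batch."""
    return Batch.from_data_list([molecule_to_graph(s) for s in smiles_list])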

# 2. Graph Neural Network for drug molecules
class DrugGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(DrugGNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Global pooling: average node embeddings per graph
        x = global_mean_pool(x, batch)  # [batch_size, output_dim]
        return x
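
# Design note: global_mean_pool averages node embeddings into a single vector
# per graph. torch_geometric.nn also provides global_max_pool and
# global_add_pool as drop-in alternative readouts, should a different pooling
# be preferred.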

# 3. Protein encoding using sequence information
class ProteinEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super(ProteinEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, x):
        embedded = self.embedding(x)
        output, (hidden, _) = self.lstm(embedded)
        # Concatenate the final hidden states from the forward and
        # backward directions
        cat = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        return self.fc(cat)
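
# A minimal encoding sketch for real protein sequences. `protein_to_tensor`
# and the AA_TO_INDEX table are hypothetical helpers, not part of the original
# gist; they assume the 20 standard amino acids map to indices 0..19, matching
# the nn.Embedding vocabulary used above.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INDEX = {aa: i for i, aa in enumerate(AMINO_ACIDS)}


def protein_to_tensor(sequence):
    """Encode a one-letter amino-acid sequence as a [1, seq_len] LongTensor."""
    indices = [AA_TO_INDEX[aa] for aa in sequence.upper()]
    return torch.tensor([indices], dtype=torch.long)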

# 4. Drug-Target Interaction model
class GraphDTA(nn.Module):
    def __init__(self, drug_input_dim, drug_hidden_dim, drug_output_dim,
                 protein_vocab_size, protein_embedding_dim,
                 protein_hidden_dim, protein_output_dim):
        super(GraphDTA, self).__init__()
        # Drug GNN
        self.drug_gnn = DrugGNN(drug_input_dim, drug_hidden_dim, drug_output_dim)
        # Protein encoder
        self.protein_encoder = ProteinEncoder(protein_vocab_size, protein_embedding_dim,
                                              protein_hidden_dim, protein_output_dim)
        # Combined prediction head
        combined_dim = drug_output_dim + protein_output_dim
        self.fc1 = nn.Linear(combined_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)

    def forward(self, drug_data, protein_data):
        # Embed the drug graph and the protein sequence
        drug_embedding = self.drug_gnn(drug_data)
        protein_embedding = self.protein_encoder(protein_data)

        # Concatenate the two embeddings
        combined = torch.cat((drug_embedding, protein_embedding), dim=1)

        # Final prediction layers
        x = F.relu(self.fc1(combined))
        x = F.dropout(x, p=0.1, training=self.training)
        x = F.relu(self.fc2(x))
        return self.fc3(x)
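
# A minimal training sketch, assuming binding affinity is regressed with a
# mean-squared-error loss. `train_step`, the optimizer choice, and the target
# format (a [batch_size, 1] float tensor) are illustrative assumptions, not
# part of the original gist.
def train_step(model, optimizer, drug_batch, protein_seq, target):
    """Run one gradient step on a (drug, protein, affinity) batch."""
    model.train()
    optimizer.zero_grad()
    prediction = model(drug_batch, protein_seq)
    loss = F.mse_loss(prediction, target)
    loss.backward()
    optimizer.step()
    return loss.item()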

# Example usage
def main():
    # Example drug SMILES string (aspirin)
    smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"

    # Convert to a graph and wrap it in a single-example batch
    drug_graph = molecule_to_graph(smiles)
    drug_batch = Batch.from_data_list([drug_graph])

    # Dummy protein sequence (amino acids encoded as integers)
    protein_seq = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=torch.long)

    # Model parameters
    drug_input_dim = 6       # number of atom features
    drug_hidden_dim = 128
    drug_output_dim = 128
    protein_vocab_size = 20  # 20 standard amino acids
    protein_embedding_dim = 128
    protein_hidden_dim = 128
    protein_output_dim = 128

    # Create the model
    model = GraphDTA(drug_input_dim, drug_hidden_dim, drug_output_dim,
                     protein_vocab_size, protein_embedding_dim,
                     protein_hidden_dim, protein_output_dim)

    # Forward pass
    binding_affinity = model(drug_batch, protein_seq)
    print(f"Predicted binding affinity: {binding_affinity.item()}")


if __name__ == "__main__":
    main()