Skip to content

Instantly share code, notes, and snippets.

@MichelNivard
Last active April 4, 2025 07:51
Show Gist options
  • Save MichelNivard/bb07743b5ebb182934cccd0e7b7cc860 to your computer and use it in GitHub Desktop.
Save MichelNivard/bb07743b5ebb182934cccd0e7b7cc860 to your computer and use it in GitHub Desktop.
import os
import numpy as np
import sidechainnet as scn
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
# === PARAMETERS ===
# Directory that receives one image per protein; created below if missing.
output_dir = "contact_maps_128"

# Referenced only by the commented-out embed/resize calls in the main loop.
max_residues = 512
final_image_size = (256, 256)

# Distance-related settings (Å). Not referenced by the visible code, which
# relies on the defaults of make_contact_map instead.
contact_threshold = 8.0  # only used if you want binary maps
distance_clip = 20.0  # max Å distance

# Not referenced by the visible code.
scale_to_255 = True

os.makedirs(output_dir, exist_ok=True)
# === FUNCTIONS ===
def make_contact_map(coords, binary=False, threshold=8.0, protein_id=None, atom_index=0):
    """
    Compute a binary contact map or a binned distance map from 3D coordinates.

    One atom per residue is used to build the (L, L) pairwise Euclidean
    distance matrix, which is then encoded as an 8-bit image:

    * ``binary=True``: 255 where distance < ``threshold``, 0 elsewhere
      (pairs involving a missing/NaN residue are 0).
    * ``binary=False``: 0 for distance < ``threshold - 4``; 20, 40, ..., 160
      for the eight consecutive 1 Å slices covering
      [``threshold - 4``, ``threshold + 4``); 255 for everything else,
      including pairs that involve a missing/NaN residue.

    Args:
        coords (np.ndarray): Shape (L, A, 3) (A atoms per residue) or (L, 3)
            (one atom per residue). The array is never modified.
        binary (bool): Whether to return the binary contact map.
        threshold (float): Å cutoff for binary maps; also anchors the bin
            edges of the binned map.
        protein_id (str | None): Label used in the "skipped" messages. When
            None, a module-level ``protein_id`` is used if one exists (the
            driver loop in this file sets one), otherwise "<unknown>".
        atom_index (int): Atom slot selected from (L, A, 3) input; ignored
            for (L, 3) input.

    Returns:
        np.ndarray | None: (L, L) uint8 map, or None when the structure is
        skipped (fewer than 10 residues, more than 10% of residues with a NaN
        coordinate, or all selected coordinates identical).

    Raises:
        ValueError: If ``coords`` is neither 2-D nor 3-D.
    """
    if protein_id is None:
        # Backward compatibility: older callers set a module-level
        # `protein_id` instead of passing it in. Fall back to a placeholder so
        # the skip paths can no longer raise NameError.
        protein_id = globals().get("protein_id", "<unknown>")

    coords = np.asarray(coords)
    if coords.ndim == 3:
        # NOTE(review): the original code treated atom 0 as Cα. SidechainNet's
        # documented per-residue atom order is believed to start N, CA, C, O,
        # which would make Cα index 1 -- confirm and pass atom_index=1 if so.
        ca_coords = coords[:, atom_index, :]  # shape (L, 3)
    elif coords.ndim == 2:
        ca_coords = coords
    else:
        raise ValueError(
            f"coords must have shape (L, 3) or (L, A, 3), got {coords.shape}"
        )

    n_residues = ca_coords.shape[0]
    # Checked first so the missing-fraction division below never sees L == 0.
    if n_residues < 10:
        print(f"{protein_id} skipped: too short ({n_residues} residues)")
        return None

    n_missing = np.isnan(ca_coords).any(axis=1).sum()
    frac_missing = n_missing / n_residues
    if frac_missing > 0.10:
        print(f"{protein_id} skipped: {n_missing}/{n_residues} Cα coords missing ({frac_missing:.1%})")
        return None

    if np.unique(ca_coords, axis=0).shape[0] <= 1:
        print(f"{protein_id} skipped: collapsed structure (identical Cα)")
        return None

    # Pairwise distances are translation-invariant, so no re-centering is done.
    # The former in-place shift by residue 0 wrote through a view into the
    # caller's array and, when residue 0 was NaN, turned every coordinate
    # (and therefore the whole map) into NaN.
    dists = np.linalg.norm(ca_coords[:, None, :] - ca_coords[None, :, :], axis=-1)

    if binary:
        return (dists < threshold).astype(np.uint8) * 255

    # NOTE(review): levels[9] (180) is never assigned because the loop below
    # stops at bin 8. Kept as-is so generated images do not change.
    levels = [0, 20, 40, 60, 80, 100, 120, 140, 160, 180]
    contact_map = np.full(dists.shape, 255, dtype=np.uint8)

    # Bin 0: d < threshold - 3. Its top 1 Å slice is overwritten by bin 1
    # below, so the effective bin 0 is d < threshold - 4.
    contact_map[dists < (threshold - 3)] = levels[0]

    # Bins 1 to 8: 1 Å slices from threshold - 4 up to threshold + 4.
    for i in range(1, 9):
        lower = threshold - 5 + i
        upper = lower + 1
        contact_map[(dists >= lower) & (dists < upper)] = levels[i]

    return contact_map
def resize_map(contact_map, final_size=(256, 256)):
    """
    Turn a contact-map array into a PIL image rescaled to `final_size`.

    The array is wrapped with `Image.fromarray` and resized with bilinear
    resampling; the resized image is returned.
    """
    return Image.fromarray(contact_map).resize(final_size, resample=Image.BILINEAR)
def embed_and_resize_map(contact_map, max_size=512, final_size=(256, 256)):
    """
    Place a contact map in the top-left corner of a fixed canvas, then resize.

    Args:
        contact_map (np.ndarray): 2-D map to embed.
        max_size (int): Side length of the square uint8 canvas. The canvas is
            filled with 0 outside the embedded map.
        final_size (tuple): Size passed to the PIL resize call.

    Returns:
        PIL image of size `final_size`, resized with bilinear resampling.

    Maps larger than `max_size` are cropped to their top-left
    `max_size` x `max_size` corner. Previously such maps raised a shape
    mismatch error on assignment into the canvas.
    """
    rows = min(contact_map.shape[0], max_size)
    cols = min(contact_map.shape[1], max_size)
    canvas = np.full((max_size, max_size), fill_value=0, dtype=np.uint8)
    canvas[:rows, :cols] = contact_map[:rows, :cols]
    img = Image.fromarray(canvas)
    return img.resize(final_size, resample=Image.BILINEAR)
def crop_contact_map(contact_map, crop_size=256, pad_value=255, min_size=20):
    """
    Return the top-left (crop_size, crop_size) window of a contact map as a
    PIL image.

    Areas of the window not covered by the map are filled with `pad_value`.
    Raises ValueError when either dimension of the map is below `min_size`.
    """
    h, w = contact_map.shape
    if min(h, w) < min_size:
        raise ValueError(f"Contact map too small: {h}x{w} (min required: {min_size})")

    # How much of the map actually fits inside the window.
    kept_rows = min(h, crop_size)
    kept_cols = min(w, crop_size)

    window = np.full((crop_size, crop_size), pad_value, dtype=np.uint8)
    window[:kept_rows, :kept_cols] = contact_map[:kept_rows, :kept_cols]
    return Image.fromarray(window)
# === MAIN ===
# Load the dataset, build one binned distance map per protein, crop/pad it to
# 128x128 and save it as an image in `output_dir`.
print("Loading SideChainNet CASP12 dataset...")
data = scn.load(casp_version=12)
#data = scn.load("debug")
print("Generating contact maps...")
for i, sample in enumerate(data):
    try:
        # Kept as a module-level name: make_contact_map reads `protein_id`
        # for its "skipped" messages.
        protein_id = sample.id
        coords = sample.coords  # shape: (L, A, 3)
        if coords is None or coords.shape[0] < 2:
            continue

        # Compute distance matrix and preprocess
        distance_map = make_contact_map(coords, binary=False)

        # make_contact_map returns None (after printing the reason) for
        # structures it rejects. Skip them explicitly instead of letting the
        # None reach crop_contact_map and surface as an unrelated error.
        if distance_map is None:
            continue

        # Skip if it's completely white
        if np.all(distance_map == 255):
            print(f"Skipping {protein_id}: distance map is all white")
            continue

        #img = embed_and_resize_map(distance_map, max_size=max_residues, final_size=final_image_size)
        #img = resize_map(distance_map, final_size=final_image_size)
        img = crop_contact_map(distance_map, crop_size=128, pad_value=255)

        # NOTE(review): JPEG is lossy, so saved pixels may not be exactly the
        # discrete levels written by make_contact_map -- consider PNG if the
        # exact values matter downstream.
        img.save(os.path.join(output_dir, f"{protein_id}.jpg"), format='JPEG')
        if i % 100 == 0:
            print(f"[{i}] Saved: {protein_id}.jpg")
    except Exception as e:
        # Best-effort batch job: report the failure and move on to the next sample.
        print(f"Skipping {sample.id} due to error: {e}")
print("✅ Done.")
https://www.dropbox.com/scl/fo/w0e5wk62nm3pn9kvl888i/AGe-G-ggFgKef5t0BEKndV0?rlkey=ix282fqd6vd04vleldr5m85oo&dl=0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment