Last active
April 4, 2025 07:51
-
-
Save MichelNivard/bb07743b5ebb182934cccd0e7b7cc860 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
import sidechainnet as scn | |
from denoising_diffusion_pytorch import Unet, GaussianDiffusion | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
import torchvision.transforms as T | |
from PIL import Image | |
from pathlib import Path | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
# === PARAMETERS ===
# Destination folder for the generated map images; created up front so the
# save calls later in the script never hit a missing directory.
output_dir = "contact_maps_128"
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Canvas / output sizes for the embed-and-resize style of preprocessing.
max_residues = 512
final_image_size = (256, 256)

# Distance-related settings (in Å).
contact_threshold = 8.0  # only used if you want binary maps
distance_clip = 20.0  # max Å distance

scale_to_255 = True
# === FUNCTIONS === | |
def make_contact_map(coords, binary=False, threshold=8.0, *, protein_id=None):
    """
    Compute a binary contact map or a binned distance map from coordinates.

    Args:
        coords (np.ndarray): Either shape (L, A, 3) -- per-residue, per-atom
            coordinates, of which atom index 0 of each residue is used -- or
            shape (L, 3) with one point per residue. The input is never
            modified.
        binary (bool): Whether to return a binary contact map.
        threshold (float): Å distance cutoff. For binary maps it is the
            contact cutoff; for binned maps it anchors the bin edges.
        protein_id (str | None): Optional label used only in the "skipped"
            messages. Keyword-only.

    Returns:
        np.ndarray | None: (L, L) uint8 map, or None when the structure is
        rejected (fewer than 10 residues, more than 10% of residues with a
        NaN coordinate, or all points identical).

    Raises:
        ValueError: If `coords` is neither 2-D nor 3-D.
    """
    label = "structure" if protein_id is None else protein_id

    coords = np.asarray(coords)
    if coords.ndim == 3:
        # extract only Cα coordinates (atom 0 for each residue)
        # NOTE(review): whether index 0 really is Cα depends on the atom
        # ordering of the data source -- confirm against the loader.
        ca_coords = coords[:, 0, :]  # shape (L, 3)
    elif coords.ndim == 2:
        ca_coords = coords
    else:
        raise ValueError(
            f"coords must have shape (L, 3) or (L, A, 3), got {coords.shape}"
        )

    n_residues = ca_coords.shape[0]
    # Length is checked first so the missing-fraction division below can never
    # divide by zero on an empty structure.
    if n_residues < 10:
        print(f"{label} skipped: too short ({n_residues} residues)")
        return None

    n_missing = int(np.isnan(ca_coords).any(axis=1).sum())
    frac_missing = n_missing / n_residues
    if frac_missing > 0.10:
        print(f"{label} skipped: {n_missing}/{n_residues} Cα coords missing ({frac_missing:.1%})")
        return None

    if np.unique(ca_coords, axis=0).shape[0] <= 1:
        print(f"{label} skipped: collapsed structure (identical Cα)")
        return None

    # Pairwise distances are translation-invariant, so no re-centring is done.
    # (Shifting in place by residue 0 would modify the caller's array and, if
    # residue 0 is NaN, would turn every coordinate into NaN.)
    dists = np.linalg.norm(ca_coords[:, None, :] - ca_coords[None, :, :], axis=-1)

    if binary:
        return (dists < threshold).astype(np.uint8) * 255

    levels = [0, 20, 40, 60, 80, 100, 120, 140, 160, 180]
    # Everything not claimed by a bin below stays white (255). That includes
    # pairs involving a NaN coordinate, because NaN fails every comparison.
    contact_map = np.full(dists.shape, 255, dtype=np.uint8)

    # Bin 0: d < threshold - 3
    contact_map[dists < (threshold - 3)] = levels[0]

    # 1 Å slices starting at threshold - 4. The first slice overlaps bin 0 and
    # overwrites it, so with the default threshold of 8 the effective bins are
    # d < 4 -> 0, [4, 5) -> 20, ..., [11, 12) -> 160, d >= 12 -> 255.
    # NOTE(review): range(1, 9) stops at levels[8]; levels[9] (180) is never
    # written. Kept as-is so generated maps stay identical -- confirm intent.
    for i in range(1, 9):
        lower = threshold - 5 + i
        in_slice = (dists >= lower) & (dists < lower + 1)
        contact_map[in_slice] = levels[i]
    return contact_map
def resize_map(contact_map, final_size=(256, 256)):
    """Turn a map array into a PIL image rescaled to `final_size`.

    The whole map is stretched or shrunk with bilinear resampling.
    """
    return Image.fromarray(contact_map).resize(final_size, resample=Image.BILINEAR)
def embed_and_resize_map(contact_map, max_size=512, final_size=(256, 256), pad_value=0):
    """
    Place a map in the top-left corner of a square canvas, then resize it.

    Args:
        contact_map (np.ndarray): 2-D map to embed.
        max_size (int): Side length of the intermediate canvas. Rows and
            columns of `contact_map` beyond `max_size` are cropped instead of
            raising a shape-mismatch error.
        final_size (tuple[int, int]): Size passed to the bilinear resize.
        pad_value (int): Value for canvas cells not covered by the map.
            Defaults to 0, the historical behaviour.

    Returns:
        PIL.Image.Image: The resized canvas.
    """
    canvas = np.full((max_size, max_size), fill_value=pad_value, dtype=np.uint8)
    # Clamp both axes independently so oversized (or non-square) maps fit.
    rows = min(contact_map.shape[0], max_size)
    cols = min(contact_map.shape[1], max_size)
    canvas[:rows, :cols] = contact_map[:rows, :cols]
    img = Image.fromarray(canvas)
    return img.resize(final_size, resample=Image.BILINEAR)
def crop_contact_map(contact_map, crop_size=256, pad_value=255, min_size=20):
    """
    Return the top-left (crop_size, crop_size) window of a map as a PIL image.

    Maps smaller than the window are padded with `pad_value` on the bottom
    and right. A map with either side shorter than `min_size` is rejected
    with a ValueError.
    """
    h, w = contact_map.shape
    if min(h, w) < min_size:
        raise ValueError(f"Contact map too small: {h}x{w} (min required: {min_size})")

    # Portion of the source that actually fits inside the window.
    keep_rows, keep_cols = min(h, crop_size), min(w, crop_size)

    window = np.full((crop_size, crop_size), pad_value, dtype=np.uint8)
    window[:keep_rows, :keep_cols] = contact_map[:keep_rows, :keep_cols]
    return Image.fromarray(window)
# === MAIN ===
# Load the dataset, turn every usable structure into a 128x128 binned distance
# map, and write it to `output_dir` as "<protein id>.jpg".
print("Loading SideChainNet CASP12 dataset...")
data = scn.load(casp_version=12)
# For a quick trial run, load the small debug set instead: scn.load("debug")

print("Generating contact maps...")
for i, sample in enumerate(data):
    # Best-effort batch job: a failure on one sample is reported and the loop
    # moves on to the next one.
    try:
        protein_id = sample.id
        coords = sample.coords  # shape: (L, A, 3)
        if coords is None or coords.shape[0] < 2:
            continue

        # Compute distance matrix and preprocess
        distance_map = make_contact_map(coords, binary=False)

        # make_contact_map returns None for structures it rejects; handle that
        # here instead of letting crop_contact_map fail on a None input.
        if distance_map is None:
            print(f"Skipping {protein_id}: no contact map produced")
            continue

        # Skip if it's completely white
        if np.all(distance_map == 255):
            print(f"Skipping {protein_id}: distance map is all white")
            continue

        # Alternative preprocessing: embed_and_resize_map(distance_map,
        # max_size=max_residues, final_size=final_image_size) or
        # resize_map(distance_map, final_size=final_image_size).
        img = crop_contact_map(distance_map, crop_size=128, pad_value=255)

        # NOTE(review): JPEG is lossy, so the discrete gray levels are not
        # stored exactly. Format kept so downstream file names do not change.
        img.save(os.path.join(output_dir, f"{protein_id}.jpg"), format='JPEG')

        if i % 100 == 0:
            print(f"[{i}] Saved: {protein_id}.jpg")
    except Exception as e:
        print(f"Skipping {sample.id} due to error: {e}")

print("✅ Done.")
# https://www.dropbox.com/scl/fo/w0e5wk62nm3pn9kvl888i/AGe-G-ggFgKef5t0BEKndV0?rlkey=ix282fqd6vd04vleldr5m85oo&dl=0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment