"""Regress Kd against other features for data from Adaptyv round 2."""
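# Inputs (read from the working directory): replicate_summary.csv, result_summary.csv,
# and prodigy_kds.tsv (PRODIGY-predicted Kds per design).
# Outputs: prediction_performance_*.png, feature_importance_*.png, pair_grid_*.png,
# plus per-model R^2, RMSE and median fold error printed to stdout.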
import polars as pl
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.inspection import permutation_importance
from sklearn.dummy import DummyRegressor
import xgboost as xgb
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr
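# Configuration flags:
#   USE_AVERAGED_REPLICATE_DATA: average kd/kon/koff across replicates of the same design
#   USE_REPLICATE_DATA: join the replicate data so kon and koff become additional targets
#   USE_NONBINDERS: keep non-binders by imputing a ceiling Kd of 1e-4 for missing values
#   USE_SIMILARITY_CHECK: include the similarity_check column as a feature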
USE_AVERAGED_REPLICATE_DATA = False
USE_REPLICATE_DATA = False
USE_NONBINDERS = True
USE_SIMILARITY_CHECK = False
def evaluate_prediction_accuracy(y_true, y_pred, target_name, labels=None):
    r2 = r2_score(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    # Targets are log10-transformed, so the absolute error in log units converts to a fold error
    fold_errors = 10**np.abs(y_true - y_pred)
    median_fold_error = np.median(fold_errors)

    print(f"\n{target_name} Prediction Accuracy:")
    print(f"R² Score: {r2:.3f}")
    print(f"RMSE (log units): {rmse:.3f}")
    print(f"Median fold error: {median_fold_error:.1f}x")

    plt.figure(figsize=(10, 10))

    # Create scatter plot of actual vs. predicted values
    plt.scatter(y_true, y_pred, alpha=0.5)

    # Add point labels if provided
    if labels is not None:
        for i, label in enumerate(labels):
            plt.annotate(label,
                         (y_true[i], y_pred[i]),
                         xytext=(5, 5),
                         textcoords='offset points',
                         fontsize=8,
                         alpha=0.7)

    # Add diagonal line
    plt.plot([y_true.min(), y_true.max()],
             [y_true.min(), y_true.max()],
             'r--',
             label='Perfect Prediction')

    plt.xlabel(f'Actual {target_name}')
    plt.ylabel(f'Predicted {target_name}')
    plt.title(f'{target_name} Prediction Performance')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'prediction_performance_{target_name}.png', dpi=300, bbox_inches='tight')
    plt.close()

    return {
        'r2': r2,
        'rmse': rmse,
        'median_fold_error': median_fold_error
    }
def evaluate_models(X, y, target_name, labels, groups):
    lr_model = LinearRegression()
    rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
    svm_model = SVR(kernel='rbf', C=1.0, epsilon=0.1)
    xgb_model = xgb.XGBRegressor(
        objective='reg:squarederror',
        n_estimators=100,
        random_state=42,
        enable_categorical=False
    )
    baseline_model = DummyRegressor(strategy='mean')

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # Use GroupKFold instead of KFold so rows sharing a name (replicates) stay in the same fold
    gkf = GroupKFold(n_splits=5)

    lr_pred = np.zeros_like(y)
    rf_pred = np.zeros_like(y)
    svm_pred = np.zeros_like(y)
    xgb_pred = np.zeros_like(y)
    baseline_pred = np.zeros_like(y)

    # Collect out-of-fold predictions for each model
    for train_idx, val_idx in gkf.split(X_scaled, y, groups=groups):
        X_train, X_val = X_scaled[train_idx], X_scaled[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        lr_model.fit(X_train, y_train)
        rf_model.fit(X_train, y_train)
        svm_model.fit(X_train, y_train)
        xgb_model.fit(X_train, y_train)
        baseline_model.fit(X_train, y_train)

        lr_pred[val_idx] = lr_model.predict(X_val)
        rf_pred[val_idx] = rf_model.predict(X_val)
        svm_pred[val_idx] = svm_model.predict(X_val)
        xgb_pred[val_idx] = xgb_model.predict(X_val)
        baseline_pred[val_idx] = baseline_model.predict(X_val)

    # Final fit on all data (used for coefficients and feature importances)
    lr_model.fit(X_scaled, y)
    rf_model.fit(X_scaled, y)
    svm_model.fit(X_scaled, y)
    xgb_model.fit(X_scaled, y)

    # Get feature importance for SVM using permutation importance
    perm_importance = permutation_importance(svm_model, X_scaled, y, n_repeats=10, random_state=42)
    svm_feature_importance = perm_importance.importances_mean

    return {
        'baseline': {
            'model': baseline_model,
            'predictions': baseline_pred,
            'metrics': evaluate_prediction_accuracy(y, baseline_pred, f"{target_name}_baseline", labels)
        },
        'linear': {
            'model': lr_model,
            'coefficients': lr_model.coef_,
            'predictions': lr_pred,
            'metrics': evaluate_prediction_accuracy(y, lr_pred, f"{target_name}_linear", labels)
        },
        'rf': {
            'model': rf_model,
            'feature_importance': rf_model.feature_importances_,
            'predictions': rf_pred,
            'metrics': evaluate_prediction_accuracy(y, rf_pred, f"{target_name}_rf", labels)
        },
        'svm': {
            'model': svm_model,
            'feature_importance': svm_feature_importance,
            'predictions': svm_pred,
            'metrics': evaluate_prediction_accuracy(y, svm_pred, f"{target_name}_svm", labels)
        },
        'xgb': {
            'model': xgb_model,
            'feature_importance': xgb_model.feature_importances_,
            'predictions': xgb_pred,
            'metrics': evaluate_prediction_accuracy(y, xgb_pred, f"{target_name}_xgb", labels)
        }
    }
# ------------------------------------------------------------------------------------------------------
# Get replicate data
#
replicate_df = pl.read_csv("replicate_summary.csv", schema_overrides={"binding": pl.String})
if USE_AVERAGED_REPLICATE_DATA:
    averaged_df = replicate_df.group_by("name").agg([
        pl.col("kd").mean().alias("kd"),
        pl.col("kon").mean().alias("kon"),
        pl.col("koff").mean().alias("koff"),
        # Keep other columns from first occurrence
        pl.col("binding").first(),
        pl.col("binding_strength").first(),
        pl.col("expression").first()
    ])
    # Replace the original replicate_df with the averaged version
    replicate_df = averaged_df
# ------------------------------------------------------------------------------------------------------
# Get main results data
#
result_df = (pl.read_csv("result_summary.csv", schema_overrides={"binding": pl.String})
             .filter(pl.col("name") != "Cetuximab_scFv")
             .filter(pl.col("name") != "Human_EGF")
             .with_columns(seq_len=pl.col("sequence").str.len_chars())
             # add rank features
             .with_row_index("rank", offset=1)
             .sort(by="esm_pll", descending=True)
             .with_row_index("rank_pll", offset=1)
             .sort(by="iptm", descending=True)
             .with_row_index("rank_iptm", offset=1)
             .sort(by="pae_interaction", descending=False)
             .with_row_index("rank_pae", offset=1)
             .with_columns(vrank_=pl.col("rank_pll") + pl.col("rank_iptm") + pl.col("rank_pae"))
             .with_columns(rank_virtual=pl.col("vrank_").rank(method="ordinal"))
             .drop("vrank_")
             .sort(by="rank_virtual", descending=False)
             )
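# rank_virtual is a combined in-silico ranking (rank_pll + rank_iptm + rank_pae, re-ranked
# ordinally); compare it against the original row-order rank from result_summary.csv.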
spearman_corr, p_value = spearmanr(result_df['rank'], result_df['rank_virtual'])
print(f"Spearman correlation (all): {spearman_corr:.3f}, p-value: {p_value:.3f}")
spearman_corr, p_value = spearmanr(result_df[:200]['rank'], result_df[:200]['rank_virtual'])
print(f"Spearman correlation (top 200): {spearman_corr:.3f}, p-value: {p_value:.3f}")
if USE_NONBINDERS:
    result_df = (result_df
                 .with_columns(kd=pl.col("kd").fill_null(1e-4)))
# optionally, combine with replicate data to get koff and kon
if USE_REPLICATE_DATA:
    combined_df = replicate_df.join(result_df, on="name", how="inner")
    targets = ['kon', 'koff', 'kd']
else:
    combined_df = result_df
    targets = ["kd"]
# PRODIGY-predicted Kds, generated with e.g. `uv run --with prodigy-prot prodigy file.pdb`
prodigy_kds = pl.read_csv("prodigy_kds.tsv", separator='\t')
combined_df = combined_df.join(prodigy_kds, on="name")
# Define features and targets
features = ['pae_interaction', 'esm_pll', 'iptm', 'plddt', "seq_len", "prodigy_kd"]
if USE_SIMILARITY_CHECK:
    features += ['similarity_check']
# Log transform the target variables (the small offset avoids log10 of zero)
combined_df = combined_df.with_columns([pl.col(targets).add(1e-20).log10()])
# First, convert string representation of list to actual list and get all unique categories
unique_models = set()
design_models = combined_df.select('design_models').to_numpy().ravel()
for models in design_models:
    if models:  # Check if not empty
        model_list = eval(models)  # Convert string representation to list
        unique_models.update(model_list)
print("Unique models found:", unique_models)
# Create binary columns for each unique model
for model in unique_models:
    # Create column name
    col_name = f"design_model_{model}"
    # Create binary column: 1 if model is in the list, 0 if not
    combined_df = combined_df.with_columns([
        pl.col('design_models').map_elements(
            lambda x: 1 if x and model in eval(x) else 0,
            return_dtype=pl.Int64
        ).alias(col_name)
    ])

# Now add these new column names to the features list
model_features = [f"design_model_{model}" for model in unique_models]
print("Added features:", model_features)
features.extend(model_features)
# ------------------------------------------------------------------------------------------------------
# Regress each target separately
#
results = {}
for target in targets:
    print(f"\nProcessing {target}")

    clean_df = combined_df.filter(
        ~pl.col(target).is_null() &
        ~pl.col(target).is_nan()
    )
    clean_df = clean_df.fill_null(0)

    X = clean_df.select(features).to_numpy()
    y = clean_df.select(target).to_numpy().ravel()
    n_samples = len(y)
    print(f"Number of clean samples: {n_samples}")

    if n_samples >= 5:
        labels = [l[:16] for l in clean_df.select('name').to_numpy().ravel()]
        print(sorted(labels))

        # Use names as groups
        groups = clean_df.select('name').to_numpy().ravel()
        results[target] = evaluate_models(X, y, target, labels, groups)

        # Create feature importance plots
        plt.figure(figsize=(20, 5))

        # Linear Regression
        plt.subplot(1, 4, 1)
        importance = np.abs(results[target]['linear']['coefficients'])
        sorted_idx = np.argsort(importance)
        plt.barh(range(len(features)), importance[sorted_idx])
        plt.yticks(range(len(features)), [features[i] for i in sorted_idx])
        plt.title(f'Linear Regression\nFeature Importance for {target}\nCV Score: {results[target]["linear"]["metrics"]["r2"]:.3f}')

        # Random Forest
        plt.subplot(1, 4, 2)
        importance = results[target]['rf']['feature_importance']
        sorted_idx = np.argsort(importance)
        plt.barh(range(len(features)), importance[sorted_idx])
        plt.yticks(range(len(features)), [features[i] for i in sorted_idx])
        plt.title(f'Random Forest\nFeature Importance for {target}\nCV Score: {results[target]["rf"]["metrics"]["r2"]:.3f}')

        # SVM
        plt.subplot(1, 4, 3)
        importance = results[target]['svm']['feature_importance']
        sorted_idx = np.argsort(importance)
        plt.barh(range(len(features)), importance[sorted_idx])
        plt.yticks(range(len(features)), [features[i] for i in sorted_idx])
        plt.title(f'SVM\nFeature Importance for {target}\nCV Score: {results[target]["svm"]["metrics"]["r2"]:.3f}')

        # XGBoost
        plt.subplot(1, 4, 4)
        importance = results[target]['xgb']['feature_importance']
        sorted_idx = np.argsort(importance)
        plt.barh(range(len(features)), importance[sorted_idx])
        plt.yticks(range(len(features)), [features[i] for i in sorted_idx])
        plt.title(f'XGBoost\nFeature Importance for {target}\nCV Score: {results[target]["xgb"]["metrics"]["r2"]:.3f}')

        plt.tight_layout()
        plt.savefig(f'feature_importance_{target}.png')
        plt.close()
    else:
        print(f"Insufficient samples for {target} (n={n_samples})")
# ------------------------------------------------------------------------------------------------------
# Plotting PairGrid
#
valid_targets = list(results.keys())
if valid_targets:
    plot_features = features[:4]
    plot_df = combined_df.select(plot_features + valid_targets)
    for col in plot_features + valid_targets:
        plot_df = plot_df.filter(
            ~pl.col(col).is_null() &
            ~pl.col(col).is_nan()
        )
    plot_df_pd = plot_df.to_pandas()

    g = sns.PairGrid(plot_df_pd,
                     vars=plot_features + valid_targets,
                     diag_sharey=False)
    g.map_upper(sns.regplot, scatter_kws={'alpha': 0.5},
                line_kws={'color': 'red'})
    g.map_lower(sns.scatterplot, alpha=0.5)
    g.map_diag(sns.kdeplot, fill=True)
    plt.tight_layout()
    plt.savefig('pair_grid_full.png', dpi=300, bbox_inches='tight')
    plt.close()

    g = sns.PairGrid(plot_df_pd,
                     x_vars=plot_features,
                     y_vars=valid_targets,
                     height=3)
    g.map(sns.regplot,
          scatter_kws={'alpha': 0.5},
          line_kws={'color': 'red'})
    plt.tight_layout()
    plt.savefig('pair_grid_focused.png', dpi=300, bbox_inches='tight')
    plt.close()
# ------------------------------------------------------------------------------------------------------
# Summary of results
#
print("\nOverall Performance Summary:")
for target in valid_targets:
    print(f"\n{target.upper()}:")
    for model in ['linear', 'rf', 'svm', 'xgb', 'baseline']:
        print(f"\n{model.upper()} Regression:")
        metrics = results[target][model]['metrics']
        print(f"R² Score: {metrics['r2']:.3f}")
        print(f"RMSE (log units): {metrics['rmse']:.3f}")
        print(f"Median fold error: {metrics['median_fold_error']:.1f}x")