Regress Kd against other features for data from Adaptyv round 2
import polars as pl
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.inspection import permutation_importance
from sklearn.dummy import DummyRegressor
import xgboost as xgb
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr
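
# Input files expected in the working directory (column names inferred from their use below):
#   replicate_summary.csv - per-replicate measurements (name, kd, kon, koff, binding, binding_strength, expression)
#   result_summary.csv    - one row per design with sequence, kd, binding, esm_pll, iptm, pae_interaction, plddt,
#                           design_models, and (optionally) similarity_check
#   prodigy_kds.tsv       - PRODIGY-predicted Kd per design (name, prodigy_kd)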
# Configuration flags
USE_AVERAGED_REPLICATE_DATA = False  # average kd/kon/koff across replicates, one row per design
USE_REPLICATE_DATA = False           # join replicate data and also model kon and koff
USE_NONBINDERS = True                # keep non-binders by assigning them a fill-in kd
USE_SIMILARITY_CHECK = False         # include the similarity_check column as a feature
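
# Targets (kd, and optionally kon/koff) are modeled in log10 space (see the transform further down),
# so RMSE is reported in log units and 10**|actual - predicted| is a fold error:
# e.g. an error of 1 log unit corresponds to a 10x fold error.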
def evaluate_prediction_accuracy(y_true, y_pred, target_name, labels=None):
    r2 = r2_score(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    fold_errors = 10**np.abs(y_true - y_pred)
    median_fold_error = np.median(fold_errors)

    print(f"\n{target_name} Prediction Accuracy:")
    print(f"R² Score: {r2:.3f}")
    print(f"RMSE (log units): {rmse:.3f}")
    print(f"Median fold error: {median_fold_error:.1f}x")

    plt.figure(figsize=(10, 10))

    # Create scatter plot
    plt.scatter(y_true, y_pred, alpha=0.5)

    # Add labels if provided
    if labels is not None:
        for i, label in enumerate(labels):
            plt.annotate(label,
                         (y_true[i], y_pred[i]),
                         xytext=(5, 5),
                         textcoords='offset points',
                         fontsize=8,
                         alpha=0.7)

    # Add diagonal line
    plt.plot([y_true.min(), y_true.max()],
             [y_true.min(), y_true.max()],
             'r--',
             label='Perfect Prediction')

    plt.xlabel(f'Actual {target_name}')
    plt.ylabel(f'Predicted {target_name}')
    plt.title(f'{target_name} Prediction Performance')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'prediction_performance_{target_name}.png', dpi=300, bbox_inches='tight')
    plt.close()

    return {
        'r2': r2,
        'rmse': rmse,
        'median_fold_error': median_fold_error
    }
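
# evaluate_models compares a mean-only DummyRegressor baseline against linear regression,
# random forest, SVR, and XGBoost. Out-of-fold predictions are assembled with GroupKFold,
# grouped by design name, so measurements of the same design never appear in both the
# training and validation sides of a split. SVR has no native feature importances, so
# permutation importance is used for it instead.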
def evaluate_models(X, y, target_name, labels, groups):
    lr_model = LinearRegression()
    rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
    svm_model = SVR(kernel='rbf', C=1.0, epsilon=0.1)
    xgb_model = xgb.XGBRegressor(
        objective='reg:squarederror',
        n_estimators=100,
        random_state=42,
        enable_categorical=False
    )
    baseline_model = DummyRegressor(strategy='mean')

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # Use GroupKFold instead of KFold
    gkf = GroupKFold(n_splits=5)

    lr_pred = np.zeros_like(y)
    rf_pred = np.zeros_like(y)
    svm_pred = np.zeros_like(y)
    xgb_pred = np.zeros_like(y)
    baseline_pred = np.zeros_like(y)

    for train_idx, val_idx in gkf.split(X_scaled, y, groups=groups):
        X_train, X_val = X_scaled[train_idx], X_scaled[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        lr_model.fit(X_train, y_train)
        rf_model.fit(X_train, y_train)
        svm_model.fit(X_train, y_train)
        xgb_model.fit(X_train, y_train)
        baseline_model.fit(X_train, y_train)

        lr_pred[val_idx] = lr_model.predict(X_val)
        rf_pred[val_idx] = rf_model.predict(X_val)
        svm_pred[val_idx] = svm_model.predict(X_val)
        xgb_pred[val_idx] = xgb_model.predict(X_val)
        baseline_pred[val_idx] = baseline_model.predict(X_val)

    # Final fit on all data
    lr_model.fit(X_scaled, y)
    rf_model.fit(X_scaled, y)
    svm_model.fit(X_scaled, y)
    xgb_model.fit(X_scaled, y)

    # Get feature importance for SVM using permutation importance
    perm_importance = permutation_importance(svm_model, X_scaled, y, n_repeats=10, random_state=42)
    svm_feature_importance = perm_importance.importances_mean

    return {
        'baseline': {
            'model': baseline_model,
            'predictions': baseline_pred,
            'metrics': evaluate_prediction_accuracy(y, baseline_pred, f"{target_name}_baseline", labels)
        },
        'linear': {
            'model': lr_model,
            'coefficients': lr_model.coef_,
            'predictions': lr_pred,
            'metrics': evaluate_prediction_accuracy(y, lr_pred, f"{target_name}_linear", labels)
        },
        'rf': {
            'model': rf_model,
            'feature_importance': rf_model.feature_importances_,
            'predictions': rf_pred,
            'metrics': evaluate_prediction_accuracy(y, rf_pred, f"{target_name}_rf", labels)
        },
        'svm': {
            'model': svm_model,
            'feature_importance': svm_feature_importance,
            'predictions': svm_pred,
            'metrics': evaluate_prediction_accuracy(y, svm_pred, f"{target_name}_svm", labels)
        },
        'xgb': {
            'model': xgb_model,
            'feature_importance': xgb_model.feature_importances_,
            'predictions': xgb_pred,
            'metrics': evaluate_prediction_accuracy(y, xgb_pred, f"{target_name}_xgb", labels)
        }
    }
# ------------------------------------------------------------------------------------------------------
# Get replicate data
#
replicate_df = pl.read_csv("replicate_summary.csv", schema_overrides={"binding": pl.String})

if USE_AVERAGED_REPLICATE_DATA:
    averaged_df = replicate_df.group_by("name").agg([
        pl.col("kd").mean().alias("kd"),
        pl.col("kon").mean().alias("kon"),
        pl.col("koff").mean().alias("koff"),
        # Keep other columns from first occurrence
        pl.col("binding").first(),
        pl.col("binding_strength").first(),
        pl.col("expression").first()
    ])
    # Replace the original replicate_df with the averaged version
    replicate_df = averaged_df
# ------------------------------------------------------------------------------------------------------
# Get main results data
#
# "rank" is the original row order of result_summary.csv; "rank_virtual" is a consensus
# in-silico rank: the ordinal rank of the summed per-metric ranks (esm_pll, iptm, pae_interaction).
result_df = (pl.read_csv("result_summary.csv", schema_overrides={"binding": pl.String})
             .filter(pl.col("name") != "Cetuximab_scFv")
             .filter(pl.col("name") != "Human_EGF")
             .with_columns(seq_len=pl.col("sequence").str.len_chars())
             # add rank features
             .with_row_index("rank", offset=1)
             .sort(by="esm_pll", descending=True)
             .with_row_index("rank_pll", offset=1)
             .sort(by="iptm", descending=True)
             .with_row_index("rank_iptm", offset=1)
             .sort(by="pae_interaction", descending=False)
             .with_row_index("rank_pae", offset=1)
             .with_columns(vrank_=pl.col("rank_pll") + pl.col("rank_iptm") + pl.col("rank_pae"))
             .with_columns(rank_virtual=pl.col("vrank_").rank(method="ordinal"))
             .drop("vrank_")
             .sort(by="rank_virtual", descending=False)
             )

spearman_corr, p_value = spearmanr(result_df['rank'], result_df['rank_virtual'])
print(f"Spearman correlation (all): {spearman_corr:.3f}, p-value: {p_value:.3f}")

spearman_corr, p_value = spearmanr(result_df[:200]['rank'], result_df[:200]['rank_virtual'])
print(f"Spearman correlation (top 200): {spearman_corr:.3f}, p-value: {p_value:.3f}")
if USE_NONBINDERS:
    result_df = (result_df
                 .with_columns(kd=pl.col("kd").fill_null(1e-4)))

# optionally, combine with replicate data to get koff and kon
if USE_REPLICATE_DATA:
    combined_df = replicate_df.join(result_df, on="name", how="inner")
    targets = ['kon', 'koff', 'kd']
else:
    combined_df = result_df
    targets = ["kd"]
# PRODIGY-predicted Kds; generate with `uv run --with prodigy-prot prodigy file.pdb`
prodigy_kds = pl.read_csv("prodigy_kds.tsv", separator='\t')
combined_df = combined_df.join(prodigy_kds, on="name")
# Define features and targets
features = ['pae_interaction', 'esm_pll', 'iptm', 'plddt', "seq_len", "prodigy_kd"]
if USE_SIMILARITY_CHECK:
    features += ['similarity_check']

# Log transform the target variables (kd/kon/koff span orders of magnitude;
# the small offset guards against log10(0))
combined_df = combined_df.with_columns([pl.col(targets).add(1e-20).log10()])
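
# NOTE: the one-hot encoding below assumes 'design_models' stores a string representation of a
# Python list (e.g. "['1', '2']"); ast.literal_eval would be a safer drop-in for eval here.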
# First, convert string representation of list to actual list and get all unique categories
unique_models = set()
design_models = combined_df.select('design_models').to_numpy().ravel()
for models in design_models:
    if models:  # Check if not empty
        model_list = eval(models)  # Convert string representation to list
        unique_models.update(model_list)
print("Unique models found:", unique_models)

# Create binary columns for each unique model
for model in unique_models:
    # Create column name
    col_name = f"design_model_{model}"
    # Create binary column: 1 if model is in the list, 0 if not
    combined_df = combined_df.with_columns([
        pl.col('design_models').map_elements(
            lambda x: 1 if x and model in eval(x) else 0,
            return_dtype=pl.Int64
        ).alias(col_name)
    ])

# Now add these new column names to features list
model_features = [f"design_model_{model}" for model in unique_models]
print("Added features:", model_features)
features.extend(model_features)
# ------------------------------------------------------------------------------------------------------
# Regress each target separately
#
results = {}
for target in targets:
    print(f"\nProcessing {target}")

    clean_df = combined_df.filter(
        ~pl.col(target).is_null() &
        ~pl.col(target).is_nan()
    )
    clean_df = clean_df.fill_null(0)

    X = clean_df.select(features).to_numpy()
    y = clean_df.select(target).to_numpy().ravel()

    n_samples = len(y)
    print(f"Number of clean samples: {n_samples}")

    if n_samples >= 5:
        labels = [l[:16] for l in clean_df.select('name').to_numpy().ravel()]
        print(sorted(labels))

        # Use names as groups
        groups = clean_df.select('name').to_numpy().ravel()

        results[target] = evaluate_models(X, y, target, labels, groups)

        # Create feature importance plots
        plt.figure(figsize=(20, 5))

        # Linear Regression
        plt.subplot(1, 4, 1)
        importance = np.abs(results[target]['linear']['coefficients'])
        sorted_idx = np.argsort(importance)
        plt.barh(range(len(features)), importance[sorted_idx])
        plt.yticks(range(len(features)), [features[i] for i in sorted_idx])
        plt.title(f'Linear Regression\nFeature Importance for {target}\nCV Score: {results[target]["linear"]["metrics"]["r2"]:.3f}')

        # Random Forest
        plt.subplot(1, 4, 2)
        importance = results[target]['rf']['feature_importance']
        sorted_idx = np.argsort(importance)
        plt.barh(range(len(features)), importance[sorted_idx])
        plt.yticks(range(len(features)), [features[i] for i in sorted_idx])
        plt.title(f'Random Forest\nFeature Importance for {target}\nCV Score: {results[target]["rf"]["metrics"]["r2"]:.3f}')

        # SVM
        plt.subplot(1, 4, 3)
        importance = results[target]['svm']['feature_importance']
        sorted_idx = np.argsort(importance)
        plt.barh(range(len(features)), importance[sorted_idx])
        plt.yticks(range(len(features)), [features[i] for i in sorted_idx])
        plt.title(f'SVM\nFeature Importance for {target}\nCV Score: {results[target]["svm"]["metrics"]["r2"]:.3f}')

        # XGBoost
        plt.subplot(1, 4, 4)
        importance = results[target]['xgb']['feature_importance']
        sorted_idx = np.argsort(importance)
        plt.barh(range(len(features)), importance[sorted_idx])
        plt.yticks(range(len(features)), [features[i] for i in sorted_idx])
        plt.title(f'XGBoost\nFeature Importance for {target}\nCV Score: {results[target]["xgb"]["metrics"]["r2"]:.3f}')

        plt.tight_layout()
        plt.savefig(f'feature_importance_{target}.png')
        plt.close()
    else:
        print(f"Insufficient samples for {target} (n={n_samples})")
# ------------------------------------------------------------------------------------------------------
# Plotting PairGrid
#
valid_targets = list(results.keys())

if valid_targets:
    plot_features = features[:4]
    plot_df = combined_df.select(plot_features + valid_targets)
    for col in plot_features + valid_targets:
        plot_df = plot_df.filter(
            ~pl.col(col).is_null() &
            ~pl.col(col).is_nan()
        )
    plot_df_pd = plot_df.to_pandas()

    g = sns.PairGrid(plot_df_pd,
                     vars=plot_features + valid_targets,
                     diag_sharey=False)
    g.map_upper(sns.regplot, scatter_kws={'alpha': 0.5},
                line_kws={'color': 'red'})
    g.map_lower(sns.scatterplot, alpha=0.5)
    g.map_diag(sns.kdeplot, fill=True)
    plt.tight_layout()
    plt.savefig('pair_grid_full.png', dpi=300, bbox_inches='tight')
    plt.close()

    g = sns.PairGrid(plot_df_pd,
                     x_vars=plot_features,
                     y_vars=valid_targets,
                     height=3)
    g.map(sns.regplot,
          scatter_kws={'alpha': 0.5},
          line_kws={'color': 'red'})
    plt.tight_layout()
    plt.savefig('pair_grid_focused.png', dpi=300, bbox_inches='tight')
    plt.close()
# ------------------------------------------------------------------------------------------------------
# Summary of results
#
print("\nOverall Performance Summary:")
for target in valid_targets:
    print(f"\n{target.upper()}:")
    for model in ['linear', 'rf', 'svm', 'xgb', 'baseline']:
        print(f"\n{model.upper()} Regression:")
        metrics = results[target][model]['metrics']
        print(f"R² Score: {metrics['r2']:.3f}")
        print(f"RMSE (log units): {metrics['rmse']:.3f}")
        print(f"Median fold error: {metrics['median_fold_error']:.1f}x")