Created
February 11, 2025 11:16
-
-
Save MartGro/fb5ea606d5977706c3746fb2a071ab64 to your computer and use it in GitHub Desktop.
rpy2 DESeq2 Differential expression with alternative hypothesis
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import rpy2.robjects as ro | |
from rpy2.robjects import pandas2ri, Formula | |
from rpy2.robjects.packages import importr | |
import pandas as pd | |
import numpy as np | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
# Switch on rpy2's automatic pandas <-> R data.frame conversion.
pandas2ri.activate()

# Every CSV table and MA plot produced by this script is written here.
Path("deseq_merged").mkdir(parents=True, exist_ok=True)

# Handles to the R packages used by the analysis (loaded in this order).
base, deseq, stats, ashr = [
    importr(r_package) for r_package in ('base', 'DESeq2', 'stats', 'ashr')
]
# Replicate numbers per modality and time point. Sample/column names follow
# the pattern "<modality><replicate>-<timepoint>", e.g. "Lysate1-9h".
# Lysate 14h30 lists only replicates 8 and 9; SN 14h30 lists 7, 8 and 9.
_REPLICATES = {
    'Lysate': {
        '9h': (1, 2, 3),
        '11h45': (4, 5, 6),
        '14h30': (8, 9),
        '17h30': (10, 11, 12),
        '20h30': (13, 14, 15),
    },
    'SN': {
        '9h': (1, 2, 3),
        '11h45': (4, 5, 6),
        '14h30': (7, 8, 9),
        '17h30': (10, 11, 12),
        '20h30': (13, 14, 15),
    },
}

# modality -> time point -> list of sample names, in time order.
sample_groups = {
    modality_name: {
        time_label: [f'{modality_name}{rep}-{time_label}' for rep in reps]
        for time_label, reps in by_time.items()
    }
    for modality_name, by_time in _REPLICATES.items()
}

# Time points in the order they are defined for the Lysate samples.
timepoints = list(sample_groups['Lysate'])
def plot_ma(res_df, save_prefix, is_shrunk=False):
    """Draw an MA plot for a results table and save it as an SVG file.

    Every plottable gene is drawn in grey at
    (log10(baseMean), log2FoldChange); genes with ``padj < 0.05`` are drawn
    again in red. Dashed blue lines mark |log2FoldChange| = 1.

    Args:
        res_df: DataFrame with ``baseMean``, ``log2FoldChange`` and ``padj``
            columns.
        save_prefix: Label used in the plot title and the output file name.
        is_shrunk: Chooses the "shrunk" / "unshrunk" tag used in the title
            and the file name.

    The figure is written to
    ``deseq_merged/MA_plot_{save_prefix}_{shrunk|unshrunk}.svg``; the
    ``deseq_merged`` directory must already exist.
    """
    # NaN padj compares False, so untested genes are never "significant".
    significant = res_df['padj'] < 0.05
    n_significant = int(significant.sum())

    # log10 is undefined for baseMean <= 0 (it yields -inf/NaN plus a
    # RuntimeWarning), so such rows are left out of the scatter. The count
    # shown in the title is still taken from the full table.
    plottable = res_df['baseMean'] > 0
    background = res_df[plottable]
    highlighted = res_df[plottable & significant]

    shrinkage_status = "shrunk" if is_shrunk else "unshrunk"

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # All genes in grey; rasterized to keep the SVG small.
        ax.scatter(np.log10(background['baseMean']),
                   background['log2FoldChange'],
                   c='grey',
                   alpha=0.5,
                   s=1,
                   rasterized=True)
        # Significant genes on top in red.
        ax.scatter(np.log10(highlighted['baseMean']),
                   highlighted['log2FoldChange'],
                   c='red',
                   alpha=0.5,
                   s=1)
        # Threshold lines at |LFC| = 1.
        for threshold in (1, -1):
            ax.axhline(y=threshold, color='blue', linestyle='--')
        ax.set_xlabel('log10(baseMean)')
        ax.set_ylabel('log2FoldChange')
        ax.set_title(f'{save_prefix}\ngreaterAbs (n={n_significant}) - {shrinkage_status}')
        fig.savefig(f'deseq_merged/MA_plot_{save_prefix}_{shrinkage_status}.svg', dpi=300)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
def _count_significant(res_df, alpha=0.05):
    """Return (total, up, down) counts of rows with ``padj`` below *alpha*."""
    significant = res_df['padj'] < alpha  # NaN padj compares False
    lfc = res_df['log2FoldChange']
    return (
        int(significant.sum()),
        int((significant & (lfc > 0)).sum()),
        int((significant & (lfc < 0)).sum()),
    )


def run_deseq2_comparison(time1, time2, modality, counts_df=None):
    """Run one pairwise DESeq2 comparison (time2 vs time1) for a modality.

    Builds a DESeqDataSet from the samples of the two time points, fits it
    with design ``~ time``, and extracts results for the contrast
    ``["time", time2, time1]`` with ``lfcThreshold=1`` and
    ``altHypothesis="greaterAbs"``. The unshrunk results and the
    ashr-shrunk results are each written to
    ``deseq_merged/deseq2_{modality}_{time1}_vs_{time2}_{unshrunk|shrunk}.csv``,
    plotted with ``plot_ma``, and summarised on stdout.

    Args:
        time1: Time-point key used as the first factor level.
        time2: Time-point key compared against ``time1``.
        modality: Key into ``sample_groups`` (e.g. 'Lysate' or 'SN').
        counts_df: Count table with one column per sample name. When omitted,
            the module-level ``gene_level_df_merged_filtered`` is used, which
            must then be defined by the caller's environment.

    Returns:
        None. Errors raised while running the R pipeline are reported on
        stdout/stderr and swallowed so that other comparisons can still run;
        errors while selecting the count columns propagate.
    """
    import traceback

    save_prefix = f"{modality}_{time1}_vs_{time2}"
    print(f"\nProcessing {save_prefix}...")

    if counts_df is None:
        # Backward-compatible default: the table the script originally read.
        counts_df = gene_level_df_merged_filtered

    # Create subset for comparison
    group1 = sample_groups[modality][time1]
    group2 = sample_groups[modality][time2]
    samples_columns = group1 + group2
    counts_subset = counts_df[samples_columns].copy()

    # Unordered categorical with time1 listed first, so that time1 is the
    # baseline of the "time_{time2}_vs_{time1}" coefficient checked below.
    conditions = [time1] * len(group1) + [time2] * len(group2)
    coldata = pd.DataFrame({
        'time': pd.Categorical(conditions, categories=[time1, time2], ordered=False)
    }, index=samples_columns)

    try:
        # Convert to R objects
        counts_matrix = pandas2ri.py2rpy(counts_subset)
        coldata_r = pandas2ri.py2rpy(coldata)

        # Create DESeqDataSet and fit the model
        dds = deseq.DESeqDataSetFromMatrix(
            countData=counts_matrix,
            colData=coldata_r,
            design=Formula('~ time')
        )
        dds = deseq.DESeq(dds)

        # Materialise the coefficient names once as a plain Python list so
        # the membership test below does not depend on the converted type.
        coef_names = list(ro.r('resultsNames')(dds))
        print("\nAvailable coefficients:", coef_names)

        # Use consistent contrast direction for both analyses
        contrast = ro.StrVector(["time", time2, time1])

        # Results tested against |LFC| > 1, before shrinkage
        res_unshrunk = deseq.results(dds,
                                     contrast=contrast,
                                     lfcThreshold=1,
                                     altHypothesis="greaterAbs")
        res_df_unshrunk = pandas2ri.rpy2py(base.as_data_frame(res_unshrunk))
        res_df_unshrunk.to_csv(f'deseq_merged/deseq2_{save_prefix}_unshrunk.csv')
        plot_ma(res_df_unshrunk, save_prefix, is_shrunk=False)
        sig_genes_unshrunk, up_unshrunk, down_unshrunk = _count_significant(res_df_unshrunk)

        # Expected coefficient name based on reference level setting
        expected_coef = f"time_{time2}_vs_{time1}"
        if expected_coef not in coef_names:
            raise ValueError(f"Expected coefficient '{expected_coef}' not found in {coef_names}")
        print(f"Using coefficient: {expected_coef}")

        # Perform LFC shrinkage using the expected coefficient.
        # NOTE(review): per the DESeq2 lfcShrink documentation, with
        # type="ashr" and `res` supplied, `coef` is ignored and only the
        # log2FoldChange/lfcSE columns are replaced, so the padj-based
        # counts below are expected to mirror the unshrunk ones -- confirm
        # this is the intended summary.
        res_shrunk = deseq.lfcShrink(dds,
                                     coef=expected_coef,
                                     res=res_unshrunk,
                                     type="ashr")
        res_df_shrunk = pandas2ri.rpy2py(base.as_data_frame(res_shrunk))
        res_df_shrunk.to_csv(f'deseq_merged/deseq2_{save_prefix}_shrunk.csv')
        plot_ma(res_df_shrunk, save_prefix, is_shrunk=True)
        sig_genes_shrunk, up_shrunk, down_shrunk = _count_significant(res_df_shrunk)

        # Print statistics
        print(f"\nResults for {save_prefix}:")
        print("\nBefore shrinkage:")
        print(f"Significant genes (padj < 0.05): {sig_genes_unshrunk}")
        print(f"Upregulated: {up_unshrunk}")
        print(f"Downregulated: {down_unshrunk}")
        print("\nAfter shrinkage:")
        print(f"Significant genes (padj < 0.05): {sig_genes_shrunk}")
        print(f"Upregulated: {up_shrunk}")
        print(f"Downregulated: {down_shrunk}")
    except Exception as e:
        # Deliberately best-effort: report and let the caller continue with
        # the next comparison, but keep the traceback for debugging.
        print(f"Error in comparison {save_prefix}: {e}")
        traceback.print_exc()
# Run all comparisons: each earlier time point against every later one,
# repeated for both modalities.
_time_pairs = [
    (earlier, later)
    for position, earlier in enumerate(timepoints)
    for later in timepoints[position + 1:]
]
for modality in ['Lysate', 'SN']:
    print(f"\nProcessing {modality} comparisons:")
    for time1, time2 in _time_pairs:
        run_deseq2_comparison(time1, time2, modality)
print("\nAll comparisons completed!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment