Created
February 11, 2025 10:16
-
-
Save MartGro/d981c12c5561b7a620d2fabc4406d9ad to your computer and use it in GitHub Desktop.
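PyDESeq2 snippet that seems to work: pairwise DESeq2 comparisons between timepoints, run separately for the Lysate and SN sample groups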
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import combinations
from pathlib import Path

import numpy as np
import pandas as pd
import pydeseq2
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
# Create the output directory up front so later writes cannot fail on a
# missing folder.
Path("deseq_merged").mkdir(parents=True, exist_ok=True)

# Sample (column) names per modality and timepoint. The insertion order of the
# timepoints is the order used when enumerating pairwise comparisons.
# Note: the Lysate 14h30 group lists only two replicates (there is no
# "Lysate7" entry), whereas every other group has three.
sample_groups = {
    'Lysate': {
        '9h': ['Lysate1-9h', 'Lysate2-9h', 'Lysate3-9h'],
        '11h45': ['Lysate4-11h45', 'Lysate5-11h45', 'Lysate6-11h45'],
        '14h30': ['Lysate8-14h30', 'Lysate9-14h30'],
        '17h30': ['Lysate10-17h30', 'Lysate11-17h30', 'Lysate12-17h30'],
        '20h30': ['Lysate13-20h30', 'Lysate14-20h30', 'Lysate15-20h30'],
    },
    'SN': {
        '9h': ['SN1-9h', 'SN2-9h', 'SN3-9h'],
        '11h45': ['SN4-11h45', 'SN5-11h45', 'SN6-11h45'],
        '14h30': ['SN7-14h30', 'SN8-14h30', 'SN9-14h30'],
        '17h30': ['SN10-17h30', 'SN11-17h30', 'SN12-17h30'],
        '20h30': ['SN13-20h30', 'SN14-20h30', 'SN15-20h30'],
    },
}

# Ordered timepoint labels, taken from the Lysate table (SN uses the same keys).
timepoints = list(sample_groups['Lysate'])
def _snapshot_results(ds):
    """Copy the per-gene statistics currently held by *ds* into a new DataFrame."""
    return pd.DataFrame({
        'baseMean': ds.base_mean,
        'log2FoldChange': ds.results_df['log2FoldChange'],
        'lfcSE': ds.results_df['lfcSE'],
        'pvalue': ds.p_values,
        'padj': ds.padj,
    })


def _count_significant(results, alpha):
    """Return ``(total, up, down)`` counts of genes with ``padj < alpha``."""
    significant = results['padj'] < alpha  # NaN padj compares False
    lfc = results['log2FoldChange']
    up = int((significant & (lfc > 0)).sum())
    down = int((significant & (lfc < 0)).sum())
    return int(significant.sum()), up, down


def run_deseq2_comparison(time1, time2, modality, *, alpha=0.05,
                          output_dir="deseq_merged"):
    """Run one pairwise DESeq2 comparison (``time2`` vs ``time1``).

    Reads the module-level ``sample_groups`` table and the count matrix
    ``gene_level_df_merged_filtered`` (not defined in this snippet; it must
    exist before the function is called, with one column per sample name).

    Args:
        time1: Timepoint label used as the reference level.
        time2: Timepoint label compared against ``time1``.
        modality: Key of ``sample_groups`` (``'Lysate'`` or ``'SN'``).
        alpha: Significance threshold passed to ``DeseqStats`` and used for
            the printed significant-gene counts.
        output_dir: Directory receiving the CSV tables and MA plots; created
            if missing.

    Returns:
        None. Writes ``deseq2_<prefix>_{pre,post}_shrinkage.csv`` and
        ``ma_plot_<prefix>_{pre,post}_shrinkage.png`` into ``output_dir`` and
        prints summary counts.

    Errors raised during the DESeq2 analysis itself are caught and printed so
    that a failing pair does not abort a batch of comparisons; errors while
    selecting samples or casting counts still propagate.
    """
    save_prefix = f"{modality}_{time1}_vs_{time2}"
    print(f"\nProcessing {save_prefix}...")

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Select the samples of both timepoints; rows become samples after the
    # transpose, columns are the rows of the original matrix.
    reference_samples = sample_groups[modality][time1]
    compared_samples = sample_groups[modality][time2]
    samples_columns = reference_samples + compared_samples
    counts_subset_t = gene_level_df_merged_filtered[samples_columns].transpose()

    # Counts are rounded and cast to integers before being handed to DESeq2.
    counts_subset_t = counts_subset_t.round().astype(int)

    # One metadata row per sample, explicitly indexed in the same order as the
    # rows of the count matrix.
    column_info = pd.DataFrame(
        {'time': [time1] * len(reference_samples)
                 + [time2] * len(compared_samples)},
        index=samples_columns,
    )

    try:
        dds = DeseqDataSet(
            counts=counts_subset_t,
            metadata=column_info,
            design_factors=["time"],
            refit_cooks=True,
            ref_level=["time", time1],  # time1 is the reference level
        )
        dds.deseq2()

        ds = DeseqStats(
            dds,
            contrast=["time", time2, time1],  # time2 vs time1
            alpha=alpha,
            cooks_filter=True,
            independent_filter=True,
        )
        ds.run_wald_test()
        # summary() applies the Cooks filtering and independent filtering
        # requested above and fills results_df (see the PyDESeq2 docs), so
        # the private _cooks_filtering()/_independent_filtering() helpers are
        # not called directly.
        ds.summary()

        # Coefficient name expected by lfc_shrink for this contrast.
        coeff = f"time_{time2}_vs_{time1}"

        # Snapshot and save the unshrunk results immediately, so they are
        # kept even if the shrinkage step below fails.
        pre_shrinkage_results = _snapshot_results(ds)
        pre_significant, pre_upregulated, pre_downregulated = (
            _count_significant(pre_shrinkage_results, alpha)
        )
        pre_shrinkage_results.to_csv(
            out_dir / f"deseq2_{save_prefix}_pre_shrinkage.csv"
        )
        ds.plot_MA(
            log=True,
            save_path=str(out_dir / f"ma_plot_{save_prefix}_pre_shrinkage.png"),
        )

        # Print available coefficients for debugging a coefficient-name mismatch.
        print("\nAvailable LFC coefficients:", ds.LFC.columns)

        ds.lfc_shrink(coeff=coeff)

        ds.plot_MA(
            log=True,
            save_path=str(out_dir / f"ma_plot_{save_prefix}_post_shrinkage.png"),
        )
        post_shrinkage_results = _snapshot_results(ds)
        post_significant, post_upregulated, post_downregulated = (
            _count_significant(post_shrinkage_results, alpha)
        )
        post_shrinkage_results.to_csv(
            out_dir / f"deseq2_{save_prefix}_post_shrinkage.csv"
        )

        print(f"\nResults for {save_prefix}:")
        print("\nPre-shrinkage statistics:")
        print(f"Total significant genes (padj < {alpha}): {pre_significant}")
        print(f"Upregulated: {pre_upregulated}")
        print(f"Downregulated: {pre_downregulated}")
        print("\nPost-shrinkage statistics:")
        print(f"Total significant genes (padj < {alpha}): {post_significant}")
        print(f"Upregulated: {post_upregulated}")
        print(f"Downregulated: {post_downregulated}")
    except Exception as e:
        # Deliberate best-effort: report the failure and return so the caller
        # can continue with the next comparison.
        print(f"Error in comparison {save_prefix}: {str(e)}")
# Run every pairwise comparison (earlier timepoint as reference, later one as
# the compared level) for each modality. combinations() yields the pairs in
# the same order as a nested "i, then i+1.." loop over timepoints.
for modality in ['Lysate', 'SN']:
    print(f"\nProcessing {modality} comparisons:")
    for time1, time2 in combinations(timepoints, 2):
        run_deseq2_comparison(time1, time2, modality)

print("\nAll comparisons completed!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment