Skip to content

Instantly share code, notes, and snippets.

@MartGro
Created February 11, 2025 10:16
Show Gist options
  • Save MartGro/d981c12c5561b7a620d2fabc4406d9ad to your computer and use it in GitHub Desktop.
Save MartGro/d981c12c5561b7a620d2fabc4406d9ad to your computer and use it in GitHub Desktop.
PyDESeq2 snippet that appears to work
import pydeseq2
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
import pandas as pd
import numpy as np
from itertools import combinations
from pathlib import Path
# Every CSV and MA plot written by the comparisons lands in this folder;
# create it once up front (no error if it is already there).
Path("deseq_merged").mkdir(exist_ok=True, parents=True)

# Replicate sample (column) names, grouped first by modality and then by
# time point. Key order is significant: `timepoints` inherits it, and the
# pairwise comparisons are enumerated in this order.
sample_groups = {
    "Lysate": {
        "9h": ["Lysate1-9h", "Lysate2-9h", "Lysate3-9h"],
        "11h45": ["Lysate4-11h45", "Lysate5-11h45", "Lysate6-11h45"],
        # Only two Lysate replicates are listed at 14h30 (no Lysate7-14h30).
        # NOTE(review): presumably excluded upstream - confirm.
        "14h30": ["Lysate8-14h30", "Lysate9-14h30"],
        "17h30": ["Lysate10-17h30", "Lysate11-17h30", "Lysate12-17h30"],
        "20h30": ["Lysate13-20h30", "Lysate14-20h30", "Lysate15-20h30"],
    },
    "SN": {
        "9h": ["SN1-9h", "SN2-9h", "SN3-9h"],
        "11h45": ["SN4-11h45", "SN5-11h45", "SN6-11h45"],
        "14h30": ["SN7-14h30", "SN8-14h30", "SN9-14h30"],
        "17h30": ["SN10-17h30", "SN11-17h30", "SN12-17h30"],
        "20h30": ["SN13-20h30", "SN14-20h30", "SN15-20h30"],
    },
}

# Chronologically ordered time-point labels; both modalities use the same keys.
timepoints = [*sample_groups["Lysate"]]
def _results_snapshot(ds):
    """Return a copy of the per-gene statistics currently held by *ds*.

    Columns: baseMean, log2FoldChange, lfcSE, pvalue, padj (indexed like
    ``ds.results_df``). ``copy=True`` decouples the snapshot from *ds*, so a
    later ``lfc_shrink`` call cannot alter a table taken before it.
    """
    return pd.DataFrame(
        {
            "baseMean": ds.base_mean,
            "log2FoldChange": ds.results_df["log2FoldChange"],
            "lfcSE": ds.results_df["lfcSE"],
            "pvalue": ds.p_values,
            "padj": ds.padj,
        },
        copy=True,
    )


def _print_significance(label, table, alpha):
    """Print significant / upregulated / downregulated gene counts.

    A gene counts as significant when ``padj < alpha``; NaN padj values
    compare False and are therefore not counted. Direction is the sign of
    the table's log2FoldChange column.
    """
    significant = table["padj"] < alpha
    lfc = table["log2FoldChange"]
    print(f"\n{label} statistics:")
    print(f"Total significant genes (padj < {alpha}): {int(significant.sum())}")
    print(f"Upregulated: {int((significant & (lfc > 0)).sum())}")
    print(f"Downregulated: {int((significant & (lfc < 0)).sum())}")


def run_deseq2_comparison(time1, time2, modality, counts_df=None, alpha=0.05,
                          output_dir="deseq_merged"):
    """Run a two-group PyDESeq2 Wald test of *time2* against *time1*.

    Writes pre- and post-LFC-shrinkage result CSVs and MA plots into
    *output_dir* (file names are prefixed ``{modality}_{time1}_vs_{time2}``)
    and prints significance counts for both tables.

    Parameters
    ----------
    time1 : str
        Reference time point; a key of ``sample_groups[modality]``.
    time2 : str
        Tested time point; a key of ``sample_groups[modality]``.
    modality : str
        Key of the module-level ``sample_groups`` mapping.
    counts_df : pandas.DataFrame, optional
        Count table whose columns include the sample names of both groups
        (rows presumably genes). Defaults to the module-level
        ``gene_level_df_merged_filtered``, which is defined outside this
        snippet and must exist at call time.
    alpha : float, default 0.05
        Significance level passed to ``DeseqStats`` and used for the
        printed counts.
    output_dir : str or pathlib.Path, default "deseq_merged"
        Output directory; created if missing.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame) or None
        ``(pre_shrinkage_results, post_shrinkage_results)``, or None when the
        analysis raised (the error and traceback are printed, and the caller
        can continue with the next comparison).

    Raises
    ------
    KeyError
        If *modality*, *time1* or *time2* is not in ``sample_groups``, or a
        sample column is missing from the count table.
    """
    if counts_df is None:
        counts_df = gene_level_df_merged_filtered

    save_prefix = f"{modality}_{time1}_vs_{time2}"
    print(f"\nProcessing {save_prefix}...")

    reference_samples = sample_groups[modality][time1]
    tested_samples = sample_groups[modality][time2]
    samples_columns = reference_samples + tested_samples

    # PyDESeq2 takes samples as rows, so transpose; round and cast because
    # DESeq2-style models expect integer counts.
    counts_subset_t = counts_df[samples_columns].transpose().round().astype(int)

    # One 'time' label per sample, indexed exactly like the count rows so the
    # two tables are aligned by construction.
    column_info = pd.DataFrame(
        {"time": [time1] * len(reference_samples) + [time2] * len(tested_samples)},
        index=samples_columns,
    )

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    try:
        dds = DeseqDataSet(
            counts=counts_subset_t,
            metadata=column_info,
            design_factors=["time"],
            refit_cooks=True,
            ref_level=["time", time1],  # time1 is the baseline level
        )
        dds.deseq2()

        ds = DeseqStats(
            dds,
            contrast=["time", time2, time1],  # log2 fold change of time2 vs time1
            alpha=alpha,
            cooks_filter=True,
            independent_filter=True,
        )
        # Use only the public summary(): when no Wald test has run yet it runs
        # the test, then Cook's filtering and independent filtering per the
        # flags above, and fills ds.results_df. Calling run_wald_test() first
        # makes summary() skip Cook's filtering, which previously forced calls
        # to the private _cooks_filtering/_independent_filtering methods.
        # NOTE(review): behaviour verified against PyDESeq2 0.4.x - re-check on upgrade.
        ds.summary()

        # Design-matrix column for time2 relative to the time1 reference.
        coeff = f"time_{time2}_vs_{time1}"

        pre_shrinkage_results = _results_snapshot(ds)
        ds.plot_MA(log=True,
                   save_path=str(out_dir / f"ma_plot_{save_prefix}_pre_shrinkage.png"))

        print("\nAvailable LFC coefficients:", ds.LFC.columns)
        # apeGLM-style shrinkage replaces log2FoldChange/lfcSE in results_df;
        # p-values and padj are left unchanged.
        ds.lfc_shrink(coeff=coeff)
        ds.plot_MA(log=True,
                   save_path=str(out_dir / f"ma_plot_{save_prefix}_post_shrinkage.png"))
        post_shrinkage_results = _results_snapshot(ds)

        pre_shrinkage_results.to_csv(out_dir / f"deseq2_{save_prefix}_pre_shrinkage.csv")
        post_shrinkage_results.to_csv(out_dir / f"deseq2_{save_prefix}_post_shrinkage.csv")

        print(f"\nResults for {save_prefix}:")
        _print_significance("Pre-shrinkage", pre_shrinkage_results, alpha)
        _print_significance("Post-shrinkage", post_shrinkage_results, alpha)
        return pre_shrinkage_results, post_shrinkage_results
    except Exception as e:
        # Best-effort by design: one failed comparison should not abort the
        # batch, but keep the traceback so the failure can be diagnosed.
        import traceback

        print(f"Error in comparison {save_prefix}: {str(e)}")
        traceback.print_exc()
        return None
# Compare every pair of time points within each modality. combinations()
# yields pairs in list order, so the earlier time point is always time1
# (the reference) - exactly the order of a nested i < j loop.
for modality in ("Lysate", "SN"):
    print(f"\nProcessing {modality} comparisons:")
    for time1, time2 in combinations(timepoints, 2):
        run_deseq2_comparison(time1, time2, modality)
print("\nAll comparisons completed!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment