Last active
September 10, 2015 09:14
-
-
Save kgori/702678a72e745265ad55 to your computer and use it in GitHub Desktop.
Maximum likelihood distance between pairs of sequences
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def setup_logger():
    """Reset the root logger to a single stream handler at INFO level.

    Every handler currently attached to the root logger is removed, then one
    ``logging.StreamHandler`` with a timestamped format is attached and the
    logger level is set to ``logging.INFO``.

    Returns:
        The configured root ``logging.Logger``.
    """
    import logging

    logger = logging.getLogger()
    # Iterate over a snapshot: removeHandler() mutates logger.handlers, and
    # removing entries while iterating the live list skips every other one.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
    ch = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logger.setLevel(logging.INFO)
    return logger


# Module-level logger, configured once at import time.
logger = setup_logger()
# Per-symbol partial (conditional likelihood) vectors for DNA.
# Columns are ordered T, C, A, G. Each membership string below lists the
# symbols compatible with that column's state: the base itself in either case
# plus the fully ambiguous symbols 'N', 'n' and '-'.
dna_charmap = {
    symbol: [float(symbol in compatible)
             for compatible in ('TtNn-', 'CcNn-', 'AaNn-', 'GgNn-')]
    for symbol in '-ACGNTacgnt'
}
# Per-symbol partial (conditional likelihood) vectors for amino acids.
# Column order: A R N D C Q E G H I L K M F P S T W Y V
# A residue (either case) is 1.0 in its own column only; the symbols
# '-', '?', 'X' and 'x' are compatible with every state (all 1.0).
protein_charmap = {
    symbol: [1.0 if (symbol in '-?Xx' or symbol.upper() == residue) else 0.0
             for residue in 'ARNDCQEGHILKMFPSTWYV']
    for symbol in '-?ARNDCQEGHILKMFPSTWYVXarndcqeghilkmfpstwyvx'
}
# 20 x 20 table of pairwise rates with a zero diagonal.
# NOTE(review): presumably the LG amino-acid exchangeabilities, with rows and
# columns in the residue order A R N D C Q E G H I L K M F P S T W Y V used by
# protein_charmap -- confirm against the published matrix before relying on it.
lg = np.array(
    [[ 0. , 0.425093, 0.276818, 0.395144, 2.489084, 0.969894, 1.038545, 2.06604 , 0.358858, 0.14983 , 0.395337, 0.536518, 1.124035, 0.253701, 1.177651, 4.727182, 2.139501, 0.180717, 0.218959, 2.54787 ],
    [ 0.425093, 0. , 0.751878, 0.123954, 0.534551, 2.807908, 0.36397 , 0.390192, 2.426601, 0.126991, 0.301848, 6.326067, 0.484133, 0.052722, 0.332533, 0.858151, 0.578987, 0.593607, 0.31444 , 0.170887],
    [ 0.276818, 0.751878, 0. , 5.076149, 0.528768, 1.695752, 0.541712, 1.437645, 4.509238, 0.191503, 0.068427, 2.145078, 0.371004, 0.089525, 0.161787, 4.008358, 2.000679, 0.045376, 0.612025, 0.083688],
    [ 0.395144, 0.123954, 5.076149, 0. , 0.062556, 0.523386, 5.24387 , 0.844926, 0.927114, 0.01069 , 0.015076, 0.282959, 0.025548, 0.017416, 0.394456, 1.240275, 0.42586 , 0.02989 , 0.135107, 0.037967],
    [ 2.489084, 0.534551, 0.528768, 0.062556, 0. , 0.084808, 0.003499, 0.569265, 0.640543, 0.320627, 0.594007, 0.013266, 0.89368 , 1.105251, 0.075382, 2.784478, 1.14348 , 0.670128, 1.165532, 1.959291],
    [ 0.969894, 2.807908, 1.695752, 0.523386, 0.084808, 0. , 4.128591, 0.267959, 4.813505, 0.072854, 0.582457, 3.234294, 1.672569, 0.035855, 0.624294, 1.223828, 1.080136, 0.236199, 0.257336, 0.210332],
    [ 1.038545, 0.36397 , 0.541712, 5.24387 , 0.003499, 4.128591, 0. , 0.348847, 0.423881, 0.044265, 0.069673, 1.807177, 0.173735, 0.018811, 0.419409, 0.611973, 0.604545, 0.077852, 0.120037, 0.245034],
    [ 2.06604 , 0.390192, 1.437645, 0.844926, 0.569265, 0.267959, 0.348847, 0. , 0.311484, 0.008705, 0.044261, 0.296636, 0.139538, 0.089586, 0.196961, 1.73999 , 0.129836, 0.268491, 0.054679, 0.076701],
    [ 0.358858, 2.426601, 4.509238, 0.927114, 0.640543, 4.813505, 0.423881, 0.311484, 0. , 0.108882, 0.366317, 0.697264, 0.442472, 0.682139, 0.508851, 0.990012, 0.584262, 0.597054, 5.306834, 0.119013],
    [ 0.14983 , 0.126991, 0.191503, 0.01069 , 0.320627, 0.072854, 0.044265, 0.008705, 0.108882, 0. , 4.145067, 0.159069, 4.273607, 1.112727, 0.078281, 0.064105, 1.033739, 0.11166 , 0.232523, 10.649107],
    [ 0.395337, 0.301848, 0.068427, 0.015076, 0.594007, 0.582457, 0.069673, 0.044261, 0.366317, 4.145067, 0. , 0.1375 , 6.312358, 2.592692, 0.24906 , 0.182287, 0.302936, 0.619632, 0.299648, 1.702745],
    [ 0.536518, 6.326067, 2.145078, 0.282959, 0.013266, 3.234294, 1.807177, 0.296636, 0.697264, 0.159069, 0.1375 , 0. , 0.656604, 0.023918, 0.390322, 0.748683, 1.136863, 0.049906, 0.131932, 0.185202],
    [ 1.124035, 0.484133, 0.371004, 0.025548, 0.89368 , 1.672569, 0.173735, 0.139538, 0.442472, 4.273607, 6.312358, 0.656604, 0. , 1.798853, 0.099849, 0.34696 , 2.020366, 0.696175, 0.481306, 1.898718],
    [ 0.253701, 0.052722, 0.089525, 0.017416, 1.105251, 0.035855, 0.018811, 0.089586, 0.682139, 1.112727, 2.592692, 0.023918, 1.798853, 0. , 0.094464, 0.361819, 0.165001, 2.457121, 7.803902, 0.654683],
    [ 1.177651, 0.332533, 0.161787, 0.394456, 0.075382, 0.624294, 0.419409, 0.196961, 0.508851, 0.078281, 0.24906 , 0.390322, 0.099849, 0.094464, 0. , 1.338132, 0.571468, 0.095131, 0.089613, 0.296501],
    [ 4.727182, 0.858151, 4.008358, 1.240275, 2.784478, 1.223828, 0.611973, 1.73999 , 0.990012, 0.064105, 0.182287, 0.748683, 0.34696 , 0.361819, 1.338132, 0. , 6.472279, 0.248862, 0.400547, 0.098369],
    [ 2.139501, 0.578987, 2.000679, 0.42586 , 1.14348 , 1.080136, 0.604545, 0.129836, 0.584262, 1.033739, 0.302936, 1.136863, 2.020366, 0.165001, 0.571468, 6.472279, 0. , 0.140825, 0.245841, 2.188158],
    [ 0.180717, 0.593607, 0.045376, 0.02989 , 0.670128, 0.236199, 0.077852, 0.268491, 0.597054, 0.11166 , 0.619632, 0.049906, 0.696175, 2.457121, 0.095131, 0.248862, 0.140825, 0. , 3.151815, 0.18951 ],
    [ 0.218959, 0.31444 , 0.612025, 0.135107, 1.165532, 0.257336, 0.120037, 0.054679, 5.306834, 0.232523, 0.299648, 0.131932, 0.481306, 7.803902, 0.089613, 0.400547, 0.245841, 3.151815, 0. , 0.249313],
    [ 2.54787 , 0.170887, 0.083688, 0.037967, 1.959291, 0.210332, 0.245034, 0.076701, 0.119013, 10.649107, 1.702745, 0.185202, 1.898718, 0.654683, 0.296501, 0.098369, 2.188158, 0.18951 , 0.249313, 0. ]])
# 20 state frequencies; the six-decimal values add up to 1.000001.
# NOTE(review): presumably the equilibrium frequencies that accompany `lg`,
# in the same residue order -- confirm.
lg_freqs = np.array([0.079066, 0.055941, 0.041977, 0.053052, 0.012937, 0.040767, 0.071586, 0.057337, 0.022355, 0.062157, 0.099081, 0.064600, 0.022951, 0.042302, 0.044040, 0.061197, 0.053287, 0.012066, 0.034155, 0.069147])
def seq_to_partials(seq, alphabet='dna'):
    """Convert a sequence of symbols into an array of partial vectors.

    Args:
        seq: Iterable of single-character symbols.
        alphabet: ``'dna'`` looks symbols up in ``dna_charmap``; any other
            value looks them up in ``protein_charmap``.

    Returns:
        ``np.array`` built from one charmap row per symbol of ``seq``.

    Raises:
        KeyError: If a symbol is not a key of the selected charmap.
    """
    charmap = dna_charmap if alphabet == 'dna' else protein_charmap
    return np.array([charmap[char] for char in seq])
def get_q_matrix(rates, freqs):
    """Build a normalised instantaneous rate matrix.

    Args:
        rates: Square array of pairwise rates.
        freqs: State frequencies, one per row/column of ``rates``.

    Returns:
        A new array ``q`` with ``q[i, j] = rates[i, j] * freqs[j]`` off the
        diagonal, each row summing to zero, rescaled so that
        ``-(diag(q) * freqs).sum()`` equals 1.
    """
    scaled = rates.dot(np.diag(freqs))
    row_totals = scaled.sum(axis=1)
    # Step through the flattened matrix with stride n + 1 to hit the diagonal;
    # subtracting the full row total leaves minus the off-diagonal sum there.
    scaled.flat[::len(freqs) + 1] -= row_totals
    mean_rate = -np.sum(np.diag(scaled) * freqs)
    scaled /= mean_rate
    return scaled
def get_b_matrix(q_matrix, sqrtfreqs):
    """Return ``diag(sqrtfreqs) . q_matrix . diag(1 / sqrtfreqs)``.

    Args:
        q_matrix: Square rate matrix.
        sqrtfreqs: Array of square roots of the state frequencies.

    Returns:
        The similarity-transformed matrix (same shape as ``q_matrix``).
    """
    scale = np.diag(sqrtfreqs)
    unscale = np.diag(1 / sqrtfreqs)
    return scale.dot(q_matrix).dot(unscale)
def get_eigen(q_matrix, freqs=None):
    """Eigendecompose a rate matrix.

    Args:
        q_matrix: Square rate matrix.
        freqs: Optional state frequencies. When given, the matrix is first
            transformed with ``get_b_matrix`` and decomposed with
            ``np.linalg.eigh``; otherwise ``np.linalg.eig`` is applied to
            ``q_matrix`` directly.

    Returns:
        Tuple ``(u, l, uinv)``: eigenvector matrix, eigenvalues, and the
        inverse of the eigenvector matrix.
    """
    # NOTE(review): the source indentation was lost; the sort and explicit
    # inversion are taken to belong to the ``freqs is None`` branch only.
    if freqs is None:
        evals, evecs = np.linalg.eig(q_matrix)
        order = np.argsort(evals)  # ascending eigenvalue order
        evecs = evecs[:, order]
        return evecs, evals[order], np.linalg.inv(evecs)

    rootf = np.sqrt(freqs)
    evals, ortho = np.linalg.eigh(get_b_matrix(q_matrix, rootf))
    # Undo the sqrt-frequency scaling on the eigenvectors of the transformed
    # matrix; the inverse is built from the transpose rather than inverted.
    u = np.diag(1 / rootf).dot(ortho)
    uinv = ortho.T.dot(np.diag(rootf))
    return u, evals, uinv
class EigenDecomp(object):
    """Holds the eigendecomposition of a rate matrix as computed by get_eigen."""

    def __init__(self, qmatrix, freqs=None):
        """Decompose ``qmatrix`` (optionally using ``freqs``) and store the parts."""
        self.u, self.l, self.uinv = get_eigen(qmatrix, freqs)

    def values(self):
        """Return the stored ``(u, l, uinv)`` tuple."""
        return (self.u, self.l, self.uinv)
class TransitionMatrix(object):
    """Transition probabilities, and their derivatives, from an eigendecomposition."""

    def __init__(self, eigen):
        """Store ``eigen.values()`` (``u, l, uinv``) and the number of eigenvalues."""
        decomposition = eigen.values()
        self.values = decomposition
        self.size = len(decomposition[1])

    def _rebuild(self, spectrum):
        """Return ``u . diag(spectrum) . uinv`` for the stored decomposition."""
        u, _, uinv = self.values
        return u.dot(np.diag(spectrum)).dot(uinv)

    def get_p_matrix(self, t):
        """
        P = transition probabilities
        """
        eigvals = self.values[1]
        return self._rebuild(np.exp(eigvals * t))

    def get_dp_matrix(self, t):
        """
        First derivative of P
        """
        eigvals = self.values[1]
        return self._rebuild(eigvals * np.exp(eigvals * t))

    def get_d2p_matrix(self, t):
        """
        Second derivative of P
        """
        eigvals = self.values[1]
        return self._rebuild(eigvals * eigvals * np.exp(eigvals * t))
class PairLikelihood(object):
    """
    Likelihood of a pair of sequences separated by a single branch.

    See Yang, (2000) "Maximum Likelihood Estimation on Large Phylogenies and Analysis of
    Adaptive Evolution in Human Influenza Virus A", J. Mol. Evol.
    """

    def __init__(self, transmat, edgelen):
        """Store the transition-matrix provider and evaluate it at ``edgelen``.

        Args:
            transmat: Object providing ``get_p_matrix``, ``get_dp_matrix`` and
                ``get_d2p_matrix`` (e.g. a TransitionMatrix).
            edgelen: Initial branch length.
        """
        self.transmat = transmat
        self.update_transmat(edgelen)

    def update_transmat(self, edgelen):
        """ Update transition probabilities for new branch lengths """
        self.p = self.transmat.get_p_matrix(edgelen)
        self.dp = self.transmat.get_dp_matrix(edgelen)
        self.d2p = self.transmat.get_d2p_matrix(edgelen)

    @staticmethod
    def _weighted_sum(matrix, partial_a, partial_b, freqs):
        """Return sum_ij freqs[i] * partial_b[i] * matrix[i, j] * partial_a[j].

        The frequency has to weight the *row* (start state) of the matrix:
        pi_i * P_ij is the joint probability of the two end states, whereas
        weighting the column is only correct when all frequencies are equal.
        """
        return np.dot(np.asarray(freqs) * partial_b, np.dot(matrix, partial_a))

    def _f(self, partial_a, partial_b, freqs):
        """ Site likelihood """
        return self._weighted_sum(self.p, partial_a, partial_b, freqs)

    def _df(self, partial_a, partial_b, freqs):
        """ First derivative of the site likelihood w.r.t. branch length """
        return self._weighted_sum(self.dp, partial_a, partial_b, freqs)

    def _d2f(self, partial_a, partial_b, freqs):
        """ Second derivative of the site likelihood w.r.t. branch length """
        return self._weighted_sum(self.d2p, partial_a, partial_b, freqs)

    def calculate(self, partials_a, partials_b, freqs):
        """Accumulate log likelihood and its derivatives over all sites.

        Args:
            partials_a: Per-site partial vectors of the first sequence.
            partials_b: Per-site partial vectors of the second sequence.
            freqs: State frequencies.

        Returns:
            Tuple ``(lnl, dlnl, d2lnl)``: the log likelihood and its first and
            second derivatives with respect to branch length.
        """
        lnl = 0
        dlnl = 0
        d2lnl = 0
        # Iterate the arguments themselves, not same-named module globals.
        for sa, sb in zip(partials_a, partials_b):
            f = self._f(sa, sb, freqs)
            df = self._df(sa, sb, freqs)
            d2f = self._d2f(sa, sb, freqs)
            lnl += np.log(f)
            dlnl += df / f
            d2lnl += (f * d2f - df * df) / (f * f)
        return lnl, dlnl, d2lnl
class Likelihood(object):
    """
    Likelihood at a node joining two subtrees through a left and a right branch.

    See Yang, (2000) "Maximum Likelihood Estimation on Large Phylogenies and Analysis of
    Adaptive Evolution in Human Influenza Virus A", J. Mol. Evol.
    """

    def __init__(self, transmat, edgelen_left, edgelen_right=0):
        """ Initialise object with a TransitionMatrix and the two branch lengths.
        """
        self.transmat = transmat
        self.update_transmat(edgelen_left, edgelen_right)
        self.size = transmat.size

    def update_transmat(self, edgelen_left, edgelen_right=0):
        """ Update transition probabilities for new branch lengths """
        self.p_left = self.transmat.get_p_matrix(edgelen_left)
        self.p_right = self.transmat.get_p_matrix(edgelen_right)
        # Derivatives are taken with respect to the total path length
        # between the two descendants.
        total = edgelen_left + edgelen_right
        self.dp = self.transmat.get_dp_matrix(total)
        self.d2p = self.transmat.get_d2p_matrix(total)

    def _likvec(self, partial_a, partial_b):
        """ Calculate the likelihood vector for a site """
        return (self.p_left*partial_a).sum(1) * (self.p_right*partial_b).sum(1)

    def _f(self, partial_a, partial_b, freqs):
        """ Calculate the root likelihood for a site """
        vec = self._likvec(partial_a, partial_b)
        return (vec*freqs).sum()

    @staticmethod
    def _weighted_sum(matrix, partial_a, partial_b, freqs):
        """Return sum_ij freqs[i] * partial_b[i] * matrix[i, j] * partial_a[j].

        The frequency weights the row (start state) of the matrix so that the
        derivatives agree with ``_f`` when the frequencies are not uniform.
        """
        return np.dot(np.asarray(freqs) * partial_b, np.dot(matrix, partial_a))

    def _df(self, partial_a, partial_b, freqs):
        """ Calculate the root first derivative of the likelihood for a site """
        return self._weighted_sum(self.dp, partial_a, partial_b, freqs)

    def _d2f(self, partial_a, partial_b, freqs):
        """ Calculate the root second derivative of the likelihood for a site """
        return self._weighted_sum(self.d2p, partial_a, partial_b, freqs)

    def calculate(self, sites_a, sites_b, pi):
        """ Calculate log likelihood and first and second derivatives over all sites.
        Sites need to be given as partials (i.e. conditional probability vectors).

        Returns:
            Tuple ``(lnl, dlnl, d2lnl)``.
        """
        lnl = 0
        dlnl = 0
        d2lnl = 0
        for sa, sb in zip(sites_a, sites_b):
            f = self._f(sa, sb, pi)
            df = self._df(sa, sb, pi)
            d2f = self._d2f(sa, sb, pi)
            lnl += np.log(f)
            dlnl += df / f
            d2lnl += (f * d2f - df * df) / (f * f)
        return lnl, dlnl, d2lnl
class OptWrapper(object):
    """
    Adapter exposing likelihood derivatives as single-argument functions of
    branch length, for use with a scipy optimiser (e.g. brenth/brentq).
    """

    def __init__(self, likelihood, sites_a, sites_b, freqs):
        """Remember the likelihood object and the data it is evaluated on."""
        self.lik = likelihood
        self.sites_a, self.sites_b = sites_a, sites_b
        self.freqs = freqs
        # Branch length of the most recent evaluation; None until first use.
        self.updated = None

    def update(self, brlen):
        """Recompute lnl, dlnl and d2lnl at ``brlen`` unless already cached."""
        if self.updated == brlen:
            return
        self.updated = brlen
        self.lik.update_transmat(brlen)
        results = self.lik.calculate(self.sites_a, self.sites_b, self.freqs)
        self.lnl, self.dlnl, self.d2lnl = results

    def get_dlnl(self, brlen):
        """Return the first derivative of the log likelihood at ``brlen``."""
        self.update(brlen)
        return self.dlnl

    def get_d2lnl(self, brlen):
        """Return the second derivative of the log likelihood at ``brlen``."""
        self.update(brlen)
        return self.d2lnl

    def __str__(self):
        template = 'Branch length={}, Variance={}, Likelihood+derivatives = {} {} {}'
        variance = -1/self.d2lnl
        return template.format(self.updated, variance, self.lnl, self.dlnl, self.d2lnl)
def optimise(likelihood, partials_a, partials_b, frequencies, min_brlen=0.00001, max_brlen=10, verbose=True):
    """
    Optimise ML distance between two partials. min and max set brackets
    likelihood = Likelihood or PairLikelihood object. PairLikelihood is slightly faster.

    The distance is the root of the first derivative of the log likelihood,
    located with ``scipy.optimize.brenth`` inside [min_brlen, max_brlen].

    Returns:
        Tuple ``(distance, variance)`` where ``variance`` is ``-1 / d2lnl``
        evaluated at the returned distance.

    Raises:
        ValueError: Propagated from ``brenth`` when the derivative has the
            same sign at both brackets.
    """
    from scipy.optimize import brenth

    wrapper = OptWrapper(likelihood, partials_a, partials_b, frequencies)
    n = brenth(wrapper.get_dlnl, min_brlen, max_brlen)
    # Evaluate at the root before reporting: the optimiser's last probe is not
    # necessarily the returned value, and the wrapper caches its last probe.
    variance = -1/wrapper.get_d2lnl(n)
    if verbose:
        print(wrapper)
    return n, variance
if __name__ == '__main__':
    # --- Demo 1: ML distance between two simulated DNA sequences ---------
    # State order is T, C, A, G (as in dna_charmap), so kappa scales the
    # T<->C and A<->G entries.
    kappa = 1
    # Builtin float replaces the np.float alias, which NumPy 1.24 removed.
    k80 = np.array([[0,kappa,1,1],[kappa,0,1,1],[1,1,0,kappa],[1,1,kappa,0]], dtype=float)
    k80f = np.array([0.25,0.25,0.25,0.25])

    # Simulated data from K80, kappa=1, distance = 0.8
    sites_a = seq_to_partials('ACCCTCCGCGTTGGGTAGTCCTAGGCCCAATGGCGTTTATGCCTCGATTTTTAGTTCTACCGTCCCTACAGATGGATGCCGTCGCATAGACACTGTCAATTCCATTCGGCAGGCTTCACACTGTTGCATTTTCATTTTGTACACGGTACCAACATAGGAGTGCTGTATTGCTATATTTCCAGTACACGGCGTTGAGTCGGATGGAAACGCCGGCGGAAGACAGCTTGGCGGGTCTTCACGCATCACCGCGGGGTCTGAAAGGTATTATCGCTGCTTAAATCAGACCGGTCAAGCTTCCTGGCGGAAGGCGGCAAGGTCCAGCCACAGCATGCTTATTCCTTGTCACGCCGGGTGGAAATCTAGAGCGTCCGGTGGACACAGAGTGATTTTGTACGGGGGGTTCCATACCAGGACATTAGGGTCGGTTTACGGTCTGAGATGTATGTTGCCTTGCGGTCGACGAGCACTGATTCCCCTGAACTTCGTAAGACACATATAGTTTTAATGAAATCCCCAAAACGAGCATGGTTTCAGTATACGCGACAACTTAGGATACAACATACTGAACCAGTCCGCATTGAGGTGCCAATCAAACGGGACCGGGACTGATAAGTATAAAATAGGTTTCCCTGTCCTCTACCTACGTTATCCTCGCGTCGATTTTGATTCTTACCAAGACTGCTAATCAGGCCCTGTGGCCTGCATGTCACCATGTCAGCGTGTTTGGCTAAATTCACGGGATTGGCCTTACCGACTTACATCAGTATTTCATACATAGTTACTCGAGTTTAACGTTGACAGTTAGTCCCATGATACGGCAAAGCCTGGTTCGGCGGATTTCCGAGTACAGCATCTTCGCCCCCGAGATTGCCGCCAATGGACACCCTCCTGAGATGCAGATATGAGTGTTTTTGACACTCTGAGGCTGAGATCCTCACACTTCCGGAGCTTCCGCGATAGTCACGTGGTTATTAGACTTACGGCAGGAAAAATCATGTTA', alphabet='dna')
    sites_b = seq_to_partials('AAGCTCCGCGTAAGCTAACGACCAGTCAGCTAGGTTTAGTGCCACCAGTATGGCTAGTTCCGGAGGGCAAACCGGATGCTACCGATTGGTCACCCTCAGGGTGATTTCGCAGGGCGCTCACTTATTCCTTTTAAATCCTGCCAACAGACTAAGAAAGTTGTACGGTATTCCTATATCTTCAGTACTGCTCTTGGCCGTGCATGTAGCCGAACGACGAGGACGGTACATGAGTTTCTCACCAATTACAGGCGGTTCCATTAGGCAGTAGCTGCGGTTAGTTCATACTGCTAAAGAATCTTCTTGGAACGTGCCAAGGACCAGTCACACACATGTTGTAGTCCCTCATCGTGGTAGGCGTTCCAGACCGTCCGTGGTACACATACCAAATTTCGTACCGGCTGACTCAAAGCGGGAGTTCGCATGATACCAGGGAACGAGATGTTCAAAACGATCAGGTAGTGCCGCCATCTTTCAGGTTCTTTCGTTTCGTCCTATGATACTTGAGTAGCGGTCAAACGAAGCTCGTAGGTGACAGTTACGAGACATGCTGGGATGCAACATACTTTCGCAGTTAGCTAGTAGGTACCTATCTAGCGAATCGAGCTAGGATACCCTGATTATGCTTGTCTCCGTCCTCTTACTATGATCTCCTCGCGTGGTTTTTGCTGCTTAACCGTTGTGCCGTATAAAACAAGAGGCGGGAGTTTAGCTGTGGGAACTTCGTAGACCTTGTAAGCTGGATAGGCCCGTCCGTCGTAATTAATTACCTAAAAGAGAGTCAAACAAGCTTAAGTCGCCGAGTTAGTCGGATAAGAAGCCATTCTCTGGTCCGCCAACCTTCCCATGCCAGTACGGTTGCCGAGGTCCATTCGGTGACTGTGGGATAACCGTTGCCGGAGCTATGAGATCCATTACAACTCTGCGCCTAGGATGTTAACTCTACCGAAGTTTGCGACCCCGGAACCTGTAAATTGTCCTTAGGGTCGTAACATTTTCAAGC', alphabet='dna')

    ed = EigenDecomp(get_q_matrix(k80, k80f), k80f)
    tm = TransitionMatrix(ed)
    lk = PairLikelihood(tm, 0)
    optimise(lk, sites_a, sites_b, k80f)

    # --- Demo 2: partial likelihoods on a five-tip tree ------------------
    # Example from Section 4.2 of Ziheng's book - his value for node 6 is wrong!
    np.set_printoptions(precision=6)
    partials_1 = np.array([1, 0, 0, 0], dtype=float)
    partials_2 = np.array([0, 1, 0, 0], dtype=float)
    partials_3 = np.array([0, 0, 1, 0], dtype=float)
    partials_4 = np.array([0, 1, 0, 0], dtype=float)
    partials_5 = np.array([0, 1, 0, 0], dtype=float)
    lik_inner = Likelihood(tm, 0.1, 0.1)
    lik_outer = Likelihood(tm, 0.2, 0.2)
    lik_inner_outer = Likelihood(tm, 0.1, 0.2)
    # Combine tips into internal nodes 7 and 8, then node 6, then the root 0.
    partials_7 = lik_outer._likvec(partials_1, partials_2)
    partials_8 = lik_outer._likvec(partials_4, partials_5)
    partials_6 = lik_inner_outer._likvec(partials_7, partials_3)
    partials_0 = lik_inner._likvec(partials_6, partials_8)
    for partial in [partials_0, partials_6, partials_8, partials_7]:
        # Call form of print works on both Python 2 and Python 3.
        print(partial)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment