Last active: January 5, 2024 17:31
-
-
Save hminle/bc1a3dea64e42f8dc90c2cd617f71f6f to your computer and use it in GitHub Desktop.
DeSurv
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import torch.nn as nn | |
from nfg.nfg_torch import * | |
from nfg.nfg_api import NeuralFineGray | |
class DeSurv(NeuralFineGray):
    """DeSurv competing-risks model exposed through the NeuralFineGray API.

    Only the torch-model construction and survival prediction are overridden;
    `_preprocess_test_data`, `fitted`, `params`, `cuda` and `torch_model` come
    from the parent class.
    """

    def _gen_torch_model(self, inputdim, optimizer, risks):
        """Build the DeSurvTorch network and register the DeSurv loss.

        Args:
            inputdim: Number of covariates.
            optimizer: Optimizer name forwarded to DeSurvTorch.
            risks: Number of competing risks.

        Returns:
            A double-precision DeSurvTorch, moved to the GPU when `self.cuda > 0`.
        """
        # The DeSurv likelihood is the module-level `total_loss`; the name
        # `losses` is never imported in this file.
        self.loss = total_loss
        model = DeSurvTorch(inputdim, **self.params,
                            risks = risks,
                            optimizer = optimizer).double()
        if self.cuda > 0:
            model = model.cuda()
        return model

    def predict_survival(self, x, t, risk = None):
        """Predict survival probabilities at one or several times.

        Args:
            x: Covariates, as accepted by `_preprocess_test_data`.
            t: A scalar time or any 1-D sequence (list, tuple, array) of times.
            risk: None for overall survival `1 - sum_k F_k(t)`; otherwise the
                1-indexed risk whose `1 - F_risk(t)` is returned.

        Returns:
            numpy array of shape `(n_samples, n_times)`.

        Raises:
            Exception: If the model has not been fitted yet.
        """
        if not self.fitted:
            raise Exception("The model has not been fitted yet. Please fit the " +
                            "model using the `fit` method on some training data " +
                            "before calling `predict_survival`.")

        x = self._preprocess_test_data(x)
        # Accept a scalar or any 1-D collection of times (not only lists).
        times = [t] if np.ndim(t) == 0 else list(t)

        # NOTE(review): times are fed to the network as-is; if the inherited
        # `fit` rescales durations, the same scaling must be applied here — confirm.
        scores = []
        with torch.no_grad():  # Inference only: no autograd graph needed.
            for t_ in times:
                horizon = torch.full((len(x),), float(t_), dtype = torch.double, device = x.device)
                pred, _, _ = self.torch_model(x, horizon)
                if risk is None:
                    cif = pred.sum(1)
                else:
                    cif = pred[:, int(risk) - 1]  # Risks are 1-indexed.
                scores.append(1 - cif.unsqueeze(1).cpu().numpy())
        return np.concatenate(scores, axis = 1)
class CondODENet(nn.Module):
    """
    Code extracted from https://github.com/djdanks/DeSurv

    Conditional ODE network: for each sample it integrates a function
    `f(s, x)` of time and covariates over `[0, horizon]` using an n-point
    Gauss-Legendre quadrature, then squashes the result with `tanh`.

    Args:
        cov_dim: Number of covariates in `x`.
        layers: Hidden layer sizes of `f` (its input is `[time, x]`).
        output_dim: Number of outputs of `f`.
        act: Activation name forwarded to `create_representation`.
        n: Number of quadrature nodes.
    """

    def __init__(self, cov_dim, layers, output_dim,
                 act = "ReLU", n = 15):
        super().__init__()
        self.output_dim = output_dim
        # `last = nn.Softplus()` presumably sets the output activation, which keeps
        # f non-negative so the integral grows with the horizon — confirm in
        # create_representation.
        self.f = nn.Sequential(*create_representation(cov_dim + 1, list(layers) + [output_dim], act, last = nn.Softplus()))
        self.n = n
        u_n, w_n = np.polynomial.legendre.leggauss(n)
        # Fixed quadrature nodes/weights live in buffers rather than frozen
        # Parameters. They still follow .to()/.double()/.cuda() and are saved
        # under the same state_dict keys, but are never handed to the optimizer.
        # They are kept in float64 so a double-precision model gets full accuracy.
        self.register_buffer("u_n", torch.tensor(u_n, dtype = torch.float64)[None, :])
        self.register_buffer("w_n", torch.tensor(w_n, dtype = torch.float64)[None, :])

    def forward(self, x, horizon):
        """Return `tanh(integral_0^horizon f(s, x) ds)` for every row of `x`.

        Args:
            x: Covariates, one row per sample.
            horizon: 1-D tensor of integration upper bounds, one per row of `x`.

        Returns:
            Tensor of shape `(len(x), output_dim)`.
        """
        # Match the quadrature constants to the caller's precision.
        u_n = self.u_n.to(horizon.dtype)
        w_n = self.w_n.to(horizon.dtype)
        half = horizon.unsqueeze(-1) / 2.  # N x 1
        # Map nodes from [-1, 1] to [0, horizon]: s = horizon / 2 * (1 + u).
        tau = torch.matmul(half, 1 + u_n)  # N x n
        tau_ = torch.flatten(tau).unsqueeze(-1)  # Nn x 1: n consecutive nodes per sample
        reppedx = torch.repeat_interleave(x, self.n, dim = 0)  # Nn x d, aligned with tau_
        taux = torch.cat((tau_, reppedx), 1)  # Nn x (d+1)
        f_n = self.f(taux).reshape((len(x), self.n, self.output_dim))  # N x n x d_out
        # Weighted sum over nodes, times the change-of-variable Jacobian horizon / 2.
        pred = half * (w_n[:, :, None] * f_n).sum(dim = 1)
        return torch.tanh(pred)
class DeSurvTorch(nn.Module):
    """DeSurv network for competing risks.

    Risk k's cumulative incidence is modelled as `balance_k(x) * Fr_k(x, t)`.
    `balance` is a softmax over risks. `Fr` is CondODENet's tanh of an
    integrated, time-dependent function.

    Args:
        inputdim: Number of covariates.
        layers: Hidden sizes of the balance network (default `[100, 100, 100]`).
        act: Activation name forwarded to `create_representation`.
        layers_surv: Hidden sizes of CondODENet's `f` (default `[100]`).
        risks: Number of competing risks.
        optimizer: Optimizer name; only stored here (presumably read by the
            training code — confirm).
        n: Number of Gauss-Legendre nodes used by CondODENet.
    """

    def __init__(self, inputdim, layers = None, act = 'ReLU', layers_surv = None,
                 risks = 1, optimizer = "Adam", n = 15):
        super().__init__()
        # None sentinels avoid shared mutable default lists; values match the
        # previous defaults. Copying to a list also accepts tuples.
        layers = [100, 100, 100] if layers is None else list(layers)
        layers_surv = [100] if layers_surv is None else list(layers_surv)

        self.input_dim = inputdim
        self.risks = risks  # Competing risks
        self.optimizer = optimizer
        self.balance = nn.Sequential(*create_representation(inputdim, layers + [risks], act, last = nn.Softmax(dim = 1)))  # Balance between risks
        self.odenet = CondODENet(inputdim, layers_surv, risks, act, n = n)

    def forward(self, x, horizon):
        """Evaluate the model at the given horizons.

        Args:
            x: Covariates, one row per sample.
            horizon: 1-D tensor of times, one per row of `x`.

        Returns:
            Tuple `(balance * Fr, balance, Fr)`. The first element is the
            per-risk cumulative incidence, `balance` holds the softmax risk
            weights, and `Fr` holds CondODENet's outputs (`len(x)` rows,
            one column per risk).
        """
        balance = self.balance(x)
        Fr = self.odenet(x, horizon)
        return balance * Fr, balance, Fr
def total_loss(model, x, t, e, eps = 1e-10):
    """Negative log-likelihood of DeSurv, averaged over the batch.

    Censored rows (`e == 0`) contribute `-log(1 - sum_k F_k(t))`. A row with
    `e == k + 1` contributes `-log dF_k/dt`. With `F_k = balance_k * tanh(I_k)`
    and `dI_k/dt = f_k(t, x)`, that density is
    `log(1 - tanh(I_k)^2) + log f_k(t, x) + log balance_k`.

    Args:
        model: Module exposing `forward(x, t) -> (pred, balance, ode)`,
            `risks` and `odenet.f`.
        x: Covariates, one row per sample.
        t: 1-D event/censoring times.
        e: Event indicators: 0 for censored, `k + 1` for risk k.
        eps: Small constant added inside every log for numerical safety.

    Returns:
        Scalar tensor: the summed negative log-likelihood divided by `len(x)`.
    """
    pred, balance, ode = model.forward(x, t)

    censored = e == 0
    survival = 1 - pred[censored].sum(dim = 1)
    nll = -torch.log(survival + eps).sum()

    for k in range(model.risks):
        mask = e == (k + 1)
        # f_k evaluated at the event time, using the same [time, covariates]
        # input layout CondODENet feeds to `f` during quadrature.
        rate = model.odenet.f(torch.cat((t[mask].unsqueeze(1), x[mask]), 1))[:, k]
        tanh_slope = 1 - ode[mask][:, k] ** 2
        log_density = (torch.log(tanh_slope + eps)
                       + torch.log(rate + eps)
                       + torch.log(balance[mask][:, k] + eps))
        nll = nll - log_density.sum()

    return nll / len(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.