Unofficial and preliminary implementation. Current issues: all computation happens during the forward pass instead of the backward pass; Fisher is not implemented; only losses with an (inputs, targets) signature are supported; and gradients are only calculated with respect to the inputs.
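A rough sketch of the intended usage (illustrative only, using torch.nn.MSELoss as a stand-in for any loss with an (inputs, targets) signature; make_newton_loss is defined in the file below):

import torch

# wrap a plain (inputs, targets) loss; backward() then hands the optimizer the
# regularized Newton step w.r.t. the inputs instead of the plain gradient
newton_mse = make_newton_loss(torch.nn.MSELoss(), tik_l=1e-2)

x = torch.randn(10, requires_grad=True)
target = torch.zeros(10)
opt = torch.optim.SGD([x], lr=1.0)

loss = newton_mse(x, target)
opt.zero_grad()
loss.backward()  # x.grad now holds (H + tik_l*I)^{-1} g
opt.step()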
from collections.abc import Iterable, Sequence
from typing import Literal

import torch
def jacobian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
    """Computes the Jacobian of the tensors in `input` with respect to `wrt` via batched vector-Jacobian products."""
    flat_input = torch.cat([i.reshape(-1) for i in input])
    return torch.autograd.grad(
        flat_input,
        wrt,
        torch.eye(len(flat_input), device=input[0].device, dtype=input[0].dtype),
        retain_graph=True,
        create_graph=create_graph,
        allow_unused=True,
        is_grads_batched=True,
    )
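# Sanity-check sketch (added commentary, not part of the original gist): for
# f(x) = x.dot(x) the gradient is 2x, so the `jacobian` of that gradient with
# respect to x should be 2 * I.
def _jacobian_sanity_check(n: int = 3):
    x = torch.randn(n, requires_grad=True)
    g = torch.autograd.grad(x.dot(x), x, create_graph=True)[0]  # g = 2x
    (H,) = jacobian([g], [x])                                   # expected: 2 * I
    assert torch.allclose(H, 2 * torch.eye(n))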
def make_newton_loss(loss_fn, tik_l: float = 1e-2):
    class NewtonLoss(torch.autograd.Function):
        @staticmethod
        def forward(ctx, preds: torch.Tensor, targets: torch.Tensor):
            with torch.enable_grad():
                # necessary to flatten preds FIRST so they are part of the graph
                preds_flat = preds.ravel()
                value = loss_fn(preds_flat.view_as(preds), targets)

                # calculate gradient and hessian
                g = torch.autograd.grad(value, preds_flat, create_graph=True)[0]
                H: torch.Tensor = jacobian([g], [preds_flat])[0]

                # apply Tikhonov regularization
                if tik_l != 0:
                    H.add_(torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(tik_l))

                # newton step
                newton_step, success = torch.linalg.solve_ex(H, g)
                ctx.save_for_backward(newton_step.view_as(preds))
                return value

        @staticmethod
        def backward(ctx, *grad_outputs):
            newton_step = ctx.saved_tensors[0]  # same shape as the inputs to the loss
            return newton_step, None

    return NewtonLoss.apply
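# Illustrative check (added commentary, not part of the original gist): backward()
# returns d = (H + tik_l*I)^{-1} g as the "gradient" of preds, so for a quadratic
# loss like MSE (constant Hessian) a single step x <- x - d lands at the target
# when tik_l is zero.
def _newton_loss_sanity_check():
    newton_mse = make_newton_loss(torch.nn.MSELoss(), tik_l=0.0)
    x = torch.randn(5, requires_grad=True)
    target = torch.zeros(5)
    newton_mse(x, target).backward()
    assert torch.allclose(x - x.grad, target, atol=1e-5)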
def make_batched_newton_loss(loss_fn, tik_l: float = 1e-2):
    class BatchedNewtonLoss(torch.autograd.Function):
        @staticmethod
        def forward(ctx, preds: torch.Tensor, targets: torch.Tensor):
            with torch.enable_grad():
                # necessary to flatten and unbind preds FIRST and then re-stack so they are part of the graph
                preds_flat = preds.view(preds.size(0), -1)
                samples = preds_flat.unbind(0)
                value = loss_fn(torch.stack(samples).view_as(preds), targets)

                # calculate per-sample gradients and hessians
                per_sample_H = []
                per_sample_g = []
                for sample in samples:
                    g = torch.autograd.grad(value, sample, create_graph=True)[0]
                    H: torch.Tensor = jacobian([g], [sample])[0]
                    per_sample_g.append(g)
                    per_sample_H.append(H)

                # stack
                H = torch.stack(per_sample_H)
                g = torch.stack(per_sample_g)

                # apply Tikhonov regularization
                if tik_l != 0:
                    I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype).mul_(tik_l).unsqueeze(0)
                    H += I

                # newton step (batched solve)
                newton_step, success = torch.linalg.solve_ex(H, g)
                ctx.save_for_backward(newton_step.view_as(preds))
                return value

        @staticmethod
        def backward(ctx, *grad_outputs):
            newton_step = ctx.saved_tensors[0]  # same shape as the inputs to the loss
            return newton_step, None

    return BatchedNewtonLoss.apply
if __name__ == "__main__":
    from monai.losses import DiceFocalLoss

    dice = DiceFocalLoss(softmax=True)
    # dice = torch.nn.MSELoss()
    input = torch.randn(32, 100, device='cuda')
    target = (torch.rand(32, 100, device='cuda') > 0.5).float()

    x = input.clone().requires_grad_(True)
    opt = torch.optim.SGD([x], 1)
    print('normal dice')
    for i in range(100):
        loss = dice(x, target)
        mse = (x - target).pow(2).mean()
        print(f'{i}, {loss = }, {mse = }')
        opt.zero_grad()
        loss.backward()
        opt.step()

    newton_dice = make_batched_newton_loss(DiceFocalLoss(softmax=True), tik_l=1e-2)
    x = input.clone().requires_grad_(True)
    opt = torch.optim.SGD([x], 1)
    # this is slow on dice because it calculates 32 separate 100x100 hessians;
    # it's meant for other losses, but dice is the only one I have installed
    print('newton dice')
    for i in range(100):
        loss = newton_dice(x, target)
        mse = (x - target).pow(2).mean()
        print(f'{i}, {loss = }, {mse = }')
        opt.zero_grad()
        loss.backward()
        opt.step()