Created
February 26, 2021 19:09
-
-
Save crowsonkb/e2f7a829b06fd74c5b7a7414ab015265 to your computer and use it in GitHub Desktop.
Biased EMA for PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Exponential moving average for PyTorch. Adapted from
https://www.zijianhu.com/post/pytorch/ema/.
"""
from copy import deepcopy

import torch
from torch import nn
class EMA(nn.Module):
    """Wrap a model and keep an exponential moving average of its parameters.

    In training mode, calls are forwarded to the wrapped ``model``; otherwise
    they are forwarded to ``average``, a deep copy of the model whose
    parameters are blended towards the model's by :meth:`update`.

    Args:
        model: The module whose parameters are averaged. It is stored by
            reference, not copied.
        decay: Fraction of the previous average kept on each update; the
            model's current parameters contribute ``1 - decay``.
    """

    def __init__(self, model: nn.Module, decay: float):
        super().__init__()
        self.model = model
        self.decay = decay
        # The average starts out as an exact copy of the model.
        self.average = deepcopy(self.model)
        for param in self.average.parameters():
            # Exclude the averaged copy from gradient computation; within
            # this class it is only written to by update().
            param.requires_grad_(False)

    @torch.no_grad()
    def update(self) -> None:
        """Blend the model's current parameters into the average.

        Every averaged parameter becomes
        ``decay * average + (1 - decay) * parameter``. Buffers are not
        averaged; they are copied from the model as they are.

        Raises:
            RuntimeError: If the module is not in training mode.
            AssertionError: If the model and the average no longer have the
                same parameter names or the same buffer names.
        """
        if not self.training:
            raise RuntimeError('Update should only be called during training')

        model_params = dict(self.model.named_parameters())
        average_params = dict(self.average.named_parameters())
        assert model_params.keys() == average_params.keys(), (
            'model and average have different parameter names'
        )
        # Loop-invariant; passing it as ``alpha`` avoids allocating a scaled
        # temporary copy of every parameter.
        new_weight = 1 - self.decay
        for name, param in model_params.items():
            average_params[name].mul_(self.decay).add_(param, alpha=new_weight)

        model_buffers = dict(self.model.named_buffers())
        average_buffers = dict(self.average.named_buffers())
        assert model_buffers.keys() == average_buffers.keys(), (
            'model and average have different buffer names'
        )
        for name, buffer in model_buffers.items():
            average_buffers[name].copy_(buffer)

    def forward(self, *args, **kwargs):
        """Call the wrapped model when training, the averaged copy otherwise."""
        if self.training:
            return self.model(*args, **kwargs)
        return self.average(*args, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment