A set of functions to help configure a PyTorch optimizer's param groups with respect to learning rate, and to unfreeze the last n layers of a PyTorch model. A short usage sketch follows the code.
import numpy as np

# Imports for annotations
from torch.nn import Module
from torch.optim import Optimizer
from typing import Optional, Union, List

def get_layers(model:Module, rgrad:bool=False):
    """
    Yields all leaf layers of the model, i.e. layers that
    have no children (the contents of a Sequential, but not
    the Sequential itself). Layers without parameters are skipped.
    rgrad : yield only layers whose parameters have
        `requires_grad` set to True.
    """
    for l in model.modules():
        if len([*l.children()]) == 0:
            params = [*l.parameters()]
            if len(params) == 0:
                continue
            if rgrad and not params[0].requires_grad:
                continue
            yield l

def freeze(model:Module) -> None:
    # Freeze all parameters in the model.
    for params in model.parameters():
        params.requires_grad = False

def unfreeze(model:Module) -> None:
    # Unfreeze all parameters in the model.
    for params in model.parameters():
        params.requires_grad = True

def unf_last_n(model:Module, n:Optional[int]=None) -> None:
    """
    Unfreeze the last `n` parametric layers of the model.
    If `n is None`, all layers are unfrozen.
    """
    # Freeze all the layers first.
    freeze(model)
    # Then unfreeze only the required layers.
    if n is None:
        unfreeze(model)
    else:
        # Last `n` leaf layers, in reverse order.
        layers = [*get_layers(model)][::-1][:n]
        for layer in layers:
            unfreeze(layer)

def get_lrs(lr:slice, count:Optional[int]=None) -> List[float]:
    """
    Returns `count` learning rates increasing exponentially
    from `lr.start` to `lr.stop`.
    If `count is None` then `count = int(lr.stop / lr.start)`.
    """
    lr1 = lr.start
    lr2 = lr.stop
    if count is None:
        count = int(lr2 / lr1)
    # Common ratio so that lr1 * incr**(count - 1) == lr2.
    incr = np.exp(np.log(lr2 / lr1) / (count - 1))
    return [lr1 * incr**i for i in range(count)]
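
# A minimal sketch of what get_lrs produces (illustrative, not part
# of the original gist): with lr = slice(1e-5, 1e-3) and count = 5
# the common ratio is 100 ** (1 / 4) ~= 3.162, so the result is
# approximately [1.0e-05, 3.2e-05, 1.0e-04, 3.2e-04, 1.0e-03].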

def configure_optimizer(model:Module, optimizer:Optimizer,
                        lr:Optional[Union[List[float], slice, float]]=None,
                        unlock:Optional[Union[bool, int]]=None) -> None:
    """
    model : a PyTorch nn.Module whose params are to be optimized.
    optimizer : a PyTorch optimizer whose param groups have to
        be configured.
    lr : if lr is a `slice`, the learning rates are spread
        exponentially over all the unfrozen layers; if it is a
        list, the last entries are assigned to the unfrozen
        layers; if it is a float, it is applied to all of them;
        if it is None, the optimizer's current lr is reused.
    unlock : if unlock is True, unfreeze all the layers;
        if unlock is a number, unfreeze only the last `unlock` layers.
    """
    # Preserve every non-lr setting (momentum, weight_decay, etc.)
    # of the existing param groups.
    pgdicts = []
    param_groups = optimizer.param_groups
    for param_group in param_groups:
        pgdict = {}
        for key in param_group:
            if key not in ['lr', 'initial_lr', 'params']:
                pgdict[key] = param_group[key]
        pgdicts.append(pgdict)

    if lr is None:
        # No learning rate given: apply the same learning rate
        # to all unfrozen layers.
        lr = param_groups[0]['lr']
        optimizer.param_groups.clear()
        layers = [*get_layers(model, True)]
        for i, layer in enumerate(layers):
            if len(layers) != len(pgdicts):
                i = 0
            optimizer.add_param_group({
                'params': layer.parameters(),
                'lr': lr,
                'initial_lr': lr,
                **pgdicts[i]
            })
    # If lr is not None, apply it per layer.
    else:
        optimizer.param_groups.clear()
        if unlock is not None:
            if unlock is True:
                # Unfreeze all the layers.
                unf_last_n(model)
            else:
                # Unfreeze only the last `unlock` layers.
                unf_last_n(model, n=unlock)
        # Attach a learning rate to each unfrozen layer.
        layers = [*get_layers(model, True)]
        l = len(layers)
        if isinstance(lr, slice):
            lrs = get_lrs(lr, count=l)
        elif isinstance(lr, list):
            if len(lr) < l:
                print("insufficient lrs")
                return
            # Use the last `l` learning rates from the list.
            lrs = lr[len(lr) - l:]
        else:
            lrs = [lr] * l
        for i, (layer_lr, layer) in enumerate(zip(lrs, layers)):
            if len(layers) != len(pgdicts):
                i = 0
            optimizer.add_param_group({
                'params': layer.parameters(),
                'lr': layer_lr,
                'initial_lr': layer_lr,
                **pgdicts[i]
            })

def print_lr_layer(model:Module, optimizer:Optimizer) -> None:
    """
    Prints the learning rate assigned to each unfrozen layer.
    """
    layers = [*get_layers(model, True)]
    pgroup = optimizer.param_groups
    if len(pgroup) != len(layers):
        if len(pgroup) > 1:
            print('param_group, unfrozen layer length mismatch')
        elif len(pgroup) == 1:
            lr = pgroup[0]['lr']
            print(f"lr: {lr}, for: ")
            for l in layers:
                print(l)
    else:
        for pg, layer in zip(pgroup, layers):
            lr = pg['lr']
            print(f"lr: {lr:0.10f} :: {layer}")