Skip to content

Instantly share code, notes, and snippets.

@tmabraham
Last active April 18, 2022 08:31
Show Gist options
  • Save tmabraham/139d2012532b0361cf43a137db421614 to your computer and use it in GitHub Desktop.
from timm import create_model
from fastai.vision.learner import _add_norm
class TimmBody(nn.Module):
    """Body (feature extractor) built from a `timm` architecture.

    Creates the `timm` model with its classifier head removed (`num_classes=0`)
    and optionally cuts it:
      - `cut` int: keep the first `cut` child modules,
      - `cut` callable: apply it to the model,
      - `cut` None: use the model as-is.
    Raises `NameError` for any other `cut` (matches fastai upstream; the
    original `NamedError` here was an undefined name and itself raised
    `NameError` at runtime).
    """
    def __init__(self, arch:str, pretrained:bool=True, cut=None, n_in:int=3):
        super().__init__()
        model = create_model(arch, pretrained=pretrained, num_classes=0, in_chans=n_in)
        if isinstance(cut, int): self.model = nn.Sequential(*list(model.children())[:cut])
        elif callable(cut): self.model = cut(model)
        elif cut is None: self.model = model
        else: raise NameError("cut must be either integer or function")
        self.cut = cut
        # A cut model is a plain `nn.Sequential` with no `default_cfg`; fall back
        # to "no pooling info" instead of raising AttributeError (bug fix).
        self.need_to_pool = bool(getattr(self.model, 'default_cfg', {}).get('pool_size'))

    def forward(self, x):
        # `forward_features` bypasses timm's own pooling so the fastai head
        # (which pools itself) receives the unpooled feature map.
        if self.need_to_pool: return self.model.forward_features(x)
        else: return self.model(x)
def create_head(nf, n_out, lin_ftrs=None, ps=0.5, pool=True, concat_pool=True, first_bn=True, bn_final=False,
                lin_first=False, y_range=None):
    "Model head that takes `nf` features, runs through `lin_ftrs`, and out `n_out` classes."
    # Concat pooling stacks avg+max pooling, doubling the incoming feature count.
    if pool and concat_pool: nf *= 2
    sizes = [nf, 512, n_out] if lin_ftrs is None else [nf] + lin_ftrs + [n_out]
    use_bn = [first_bn] + [True] * len(sizes[1:])
    ps = L(ps)
    # One dropout value expands to [p/2, ..., p/2, p]: lighter dropout on hidden layers.
    if len(ps) == 1: ps = [ps[0]/2] * (len(sizes)-2) + ps
    acts = [nn.ReLU(inplace=True)] * (len(sizes)-2) + [None]
    layers = []
    if pool:
        pooling = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)
        layers += [pooling, Flatten()]
    if lin_first: layers.append(nn.Dropout(ps.pop(0)))
    # Each LinBnDrop is an nn.Sequential, so `+=` splices its modules into `layers`.
    for n_i, n_o, bn, p, act in zip(sizes[:-1], sizes[1:], use_bn, ps, acts):
        layers += LinBnDrop(n_i, n_o, bn=bn, p=p, act=act, lin_first=lin_first)
    if lin_first: layers.append(nn.Linear(sizes[-2], n_out))
    if bn_final: layers.append(nn.BatchNorm1d(sizes[-1], momentum=0.01))
    if y_range is not None: layers.append(SigmoidRange(*y_range))
    return nn.Sequential(*layers)
def create_timm_model(arch:str, n_out, cut=None, pretrained=True, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,
                      concat_pool=True, **kwargs):
    "Create custom architecture using `arch`, `n_in` and `n_out` from the `timm` library"
    # Bug fix: `cut` was previously dropped (the call always passed None),
    # so the parameter had no effect. Pass it through to the body.
    body = TimmBody(arch, pretrained, cut, n_in)
    if custom_head is None:
        # `pool=body.need_to_pool`: only add a pooling layer in the head when the
        # body emits an unpooled feature map.
        head = create_head(body.model.num_features, n_out, concat_pool=concat_pool, pool=body.need_to_pool, **kwargs)
    else: head = custom_head
    model = nn.Sequential(body, head)
    # Only the head is freshly created; initialize it (the body keeps pretrained weights).
    if init is not None: apply_init(model[1], init)
    return model
def timm_learner(dls, arch:str, loss_func=None, pretrained=True, cut=None, splitter=None,
                 y_range=None, config=None, n_out=None, normalize=True, **kwargs):
    "Build a convnet style learner from `dls` and `arch` using the `timm` library"
    if config is None: config = {}
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    if y_range is None and 'y_range' in config: y_range = config.pop('y_range')
    # Bug fix: `cut` was ignored and `default_split` (a parameter-group splitter,
    # not a cut spec) was passed in its place.
    model = create_timm_model(arch, n_out, cut, pretrained, y_range=y_range, **config)
    # Bug fix: honor a user-supplied `splitter`; fall back to fastai's `default_split`.
    learn = Learner(dls, model, loss_func=loss_func, splitter=splitter or default_split, **kwargs)
    # NOTE(review): `normalize` is accepted but never used here; fastai's cnn_learner
    # would call `_add_norm` (imported at the top of this file) — confirm whether
    # normalization stats should be added to `dls`.
    if pretrained: learn.freeze()
    return learn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment