Last active
April 18, 2022 08:31
-
-
Save tmabraham/139d2012532b0361cf43a137db421614 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from timm import create_model | |
from fastai.vision.learner import _add_norm | |
class TimmBody(nn.Module):
    """Body of a `timm` model built with no classifier, optionally truncated with `cut`.

    Args:
        arch: name of the `timm` architecture passed to `create_model`.
        pretrained: whether `create_model` loads pretrained weights.
        cut: `None` keeps the whole model; an `int` keeps the first `cut` children
            (wrapped in `nn.Sequential`); a callable receives the model and returns the body.
        n_in: number of input channels (forwarded as `in_chans`).

    Raises:
        NameError: if `cut` is not `None`, an `int`, or a callable.
    """
    def __init__(self, arch:str, pretrained:bool=True, cut=None, n_in:int=3):
        super().__init__()
        # `num_classes=0` asks timm to build the model without a classification layer.
        model = create_model(arch, pretrained=pretrained, num_classes=0, in_chans=n_in)
        if isinstance(cut, int): self.model = nn.Sequential(*list(model.children())[:cut])
        elif callable(cut): self.model = cut(model)
        elif cut is None: self.model = model
        # Previously `NamedError` (an undefined name), so this message was never shown.
        else: raise NameError("cut must be either integer or function")
        self.cut = cut
        # Read the config from the full timm model: the `nn.Sequential` built for an
        # integer `cut` (and possibly a callable's result) has no `default_cfg`.
        self.need_to_pool = bool(model.default_cfg.get('pool_size', None))

    def forward(self, x):
        "Return the body's features for the input batch `x`."
        # `forward_features` is a timm-model method; a truncated `nn.Sequential` lacks it,
        # so fall back to calling the module directly.
        # NOTE(review): assumes a truncated body still yields un-pooled feature maps.
        if self.need_to_pool and hasattr(self.model, 'forward_features'):
            return self.model.forward_features(x)
        return self.model(x)
def create_head(nf, n_out, lin_ftrs=None, ps=0.5, pool=True, concat_pool=True, first_bn=True, bn_final=False,
                lin_first=False, y_range=None):
    """Build a classifier head mapping `nf` input features to `n_out` outputs.

    Args:
        nf: number of features produced by the body.
        n_out: number of outputs of the final layer.
        lin_ftrs: hidden layer sizes; `None` means a single hidden layer of 512.
        ps: dropout probability (or one per linear block); a single value is halved
            for every hidden block and used as-is for the last one.
        pool: prepend a pooling layer plus `Flatten`.
        concat_pool: use `AdaptiveConcatPool2d` (doubles `nf`) instead of average pooling.
        first_bn: whether the first linear block uses batch norm.
        bn_final: append a `BatchNorm1d` after the last layer.
        lin_first: passed to `LinBnDrop`; also adds a leading dropout and trailing linear.
        y_range: if given, append `SigmoidRange(*y_range)`.
    """
    # Concat pooling stacks two pooled views of the features, doubling the width.
    n_feat = nf * 2 if (pool and concat_pool) else nf
    hidden = [512] if lin_ftrs is None else lin_ftrs
    sizes = [n_feat] + hidden + [n_out]
    n_inner = len(sizes) - 2

    use_bn = [first_bn] + [True] * (len(sizes) - 1)
    ps = L(ps)
    if len(ps) == 1:
        ps = [ps[0] / 2] * n_inner + ps
    acts = [nn.ReLU(inplace=True)] * n_inner + [None]

    layers = []
    if pool:
        pool_layer = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)
        layers.extend([pool_layer, Flatten()])
    if lin_first:
        layers.append(nn.Dropout(ps.pop(0)))
    for d_in, d_out, bn, p, act in zip(sizes[:-1], sizes[1:], use_bn, ps, acts):
        layers.extend(LinBnDrop(d_in, d_out, bn=bn, p=p, act=act, lin_first=lin_first))
    if lin_first:
        layers.append(nn.Linear(sizes[-2], n_out))
    if bn_final:
        layers.append(nn.BatchNorm1d(sizes[-1], momentum=0.01))
    if y_range is not None:
        layers.append(SigmoidRange(*y_range))
    return nn.Sequential(*layers)
def create_timm_model(arch:str, n_out, cut=None, pretrained=True, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,
                      concat_pool=True, **kwargs):
    """Create a `TimmBody` + head model for `arch` with `n_in` inputs and `n_out` outputs.

    Args:
        arch: `timm` architecture name.
        n_out: number of outputs of the head.
        cut: forwarded to `TimmBody` (`None`, an `int`, or a callable).
        pretrained: load pretrained weights for the body.
        n_in: number of input channels.
        init: initializer applied to the head via `apply_init`; `None` skips it.
        custom_head: module used instead of `create_head(...)` when given.
        concat_pool: forwarded to `create_head`.
        **kwargs: extra `create_head` arguments.
    """
    # `cut` used to be dropped: `None` was always passed here.
    body = TimmBody(arch, pretrained, cut, n_in)
    if custom_head is None:
        # A truncated body (`nn.Sequential`) has no `num_features`; probe it instead.
        nf = getattr(body.model, 'num_features', None)
        if nf is None:
            from fastai.callback.hook import num_features_model
            nf = num_features_model(body.model)
        head = create_head(nf, n_out, concat_pool=concat_pool, pool=body.need_to_pool, **kwargs)
    else: head = custom_head
    model = nn.Sequential(body, head)
    if init is not None: apply_init(model[1], init)
    return model

def _add_timm_norm(dls, model, pretrained):
    "Add the timm model's mean/std normalization to `dls` via fastai's `_add_norm`, if known."
    # NOTE(review): a truncated body has no `default_cfg`, in which case nothing is added.
    cfg = getattr(model[0].model, 'default_cfg', None) or {}
    if 'mean' in cfg and 'std' in cfg:
        _add_norm(dls, {'stats': (cfg['mean'], cfg['std'])}, pretrained)

def timm_learner(dls, arch:str, loss_func=None, pretrained=True, cut=None, splitter=None,
                 y_range=None, config=None, n_out=None, normalize=True, **kwargs):
    """Build a convnet-style `Learner` from `dls` and a `timm` `arch`.

    Args:
        dls: the `DataLoaders` to train on.
        arch: `timm` architecture name.
        loss_func: loss passed to `Learner`.
        pretrained: load pretrained weights and freeze the body.
        cut: forwarded to `create_timm_model`.
        splitter: parameter-group splitter; defaults to `default_split`.
        y_range: forwarded to the head; takes precedence over `config['y_range']`.
        config: extra keyword arguments for `create_timm_model` (not mutated).
        n_out: number of outputs; inferred with `get_c(dls)` when `None`.
        normalize: add the model's pretrained normalization stats to `dls`.
        **kwargs: extra `Learner` arguments.
    """
    # Copy so popping `y_range` does not mutate the caller's dict.
    config = {} if config is None else dict(config)
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    # Always pop so an explicit `y_range` cannot collide with the config entry.
    cfg_y_range = config.pop('y_range', None)
    if y_range is None: y_range = cfg_y_range
    # `cut` was previously filled positionally with `default_split`.
    model = create_timm_model(arch, n_out, cut, pretrained, y_range=y_range, **config)
    if normalize: _add_timm_norm(dls, model, pretrained)
    if splitter is None: splitter = default_split
    learn = Learner(dls, model, loss_func=loss_func, splitter=splitter, **kwargs)
    if pretrained: learn.freeze()
    return learn
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment