Skip to content

Instantly share code, notes, and snippets.

@irhum
Last active January 20, 2019 12:49
Show Gist options
  • Save irhum/ac7c27d746a4fd4b3c5a1ddac2dceb70 to your computer and use it in GitHub Desktop.
Save irhum/ac7c27d746a4fd4b3c5a1ddac2dceb70 to your computer and use it in GitHub Desktop.
def train(model, dataloaders, criterion, optimizer, device, out_name, dlib_models=None,
          validate=True, validate_every=10, num_epochs=100):
    """Run the training loop, optionally validating every few epochs.

    Args:
        model: Model forwarded to ``train_epoch`` and ``val_epoch``.
        dataloaders: Container indexed with ``'train'`` every epoch and with
            ``'val'`` on validation epochs. When ``validate`` is true it must
            hold exactly two entries.
        criterion: Loss callable forwarded to the per-epoch helpers.
        optimizer: Optimizer forwarded to ``train_epoch``.
        device: Device forwarded to the per-epoch helpers.
        out_name: Accepted for interface compatibility; not used by this
            function.
        dlib_models: Forwarded to ``val_epoch`` for the distance metric.
            Required (non-None) when ``validate`` is true.
        validate: Whether to run validation at all.
        validate_every: Validate on epochs that are a multiple of this value.
        num_epochs: Number of epochs to run; epochs are numbered from 1 to
            ``num_epochs`` inclusive.

    Raises:
        AssertionError: If ``validate`` is true and ``dataloaders`` does not
            hold exactly two entries or ``dlib_models`` is None.
    """
    if validate:
        assert len(dataloaders) == 2, \
            "validation requires exactly two dataloaders ('train' and 'val')"
        assert dlib_models is not None, \
            "validation requires dlib_models for the distance metric"

    # start at epoch 1, end at epoch num_epochs (inclusive)
    for epoch in range(1, num_epochs + 1):
        # Training phase
        trn_loss = train_epoch(model, dataloaders['train'], criterion, optimizer, device)
        print("Epoch: ", epoch, "Train Loss:", trn_loss)

        # Validation phase
        if validate and epoch % validate_every == 0:
            # val_epoch's fourth parameter is dlib_models, not the optimizer:
            # validation never steps the optimizer, and the distance metric
            # needs the dlib models.
            val_loss, dists = val_epoch(model, dataloaders['val'], criterion, dlib_models, device)
            avg_dist = np.mean(dists)
            print("Epoch: ", epoch, "Val Loss:", val_loss, "Average Distance:", avg_dist)
def train_epoch(model, trn_dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``trn_dataloader``.

    Args:
        model: Module to optimize; switched to training mode first.
        trn_dataloader: Iterable of ``(inputs, targets)`` batches that also
            exposes a ``dataset`` attribute supporting ``len()``.
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar loss tensor.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The accumulated batch losses (each multiplied by its batch size)
        divided by ``len(trn_dataloader.dataset)``.
    """
    model.train()
    n_samples = len(trn_dataloader.dataset)
    loss_sum = 0

    for inputs, targets in trn_dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Standard PyTorch step: clear old grads, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        # Scale by batch size so a smaller final batch is weighted correctly
        # (presumes the criterion averages over the batch -- TODO confirm).
        loss_sum += batch_loss.item() * inputs.size(0)

    return loss_sum / n_samples
def val_epoch(model, val_dataloader, criterion, dlib_models, device):
    """Evaluate ``model`` for one pass over ``val_dataloader``.

    Args:
        model: Module to evaluate; switched to eval mode first.
        val_dataloader: Iterable of ``(inputs, targets)`` batches that also
            exposes a ``dataset`` attribute supporting ``len()``.
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar loss tensor.
        dlib_models: Passed unchanged to ``distance_metric``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A ``(val_loss, dists)`` tuple: the accumulated batch losses (each
        multiplied by its batch size) divided by
        ``len(val_dataloader.dataset)``, and a list with one
        ``distance_metric`` result per validated sample.
    """
    # Imported locally so this function does not depend on how (or whether)
    # the rest of the file imports torch.
    import torch

    model.eval()
    dists = []
    running_loss = 0

    # Validation never backpropagates, so skip building the autograd graph.
    with torch.no_grad():
        for x, y in val_dataloader:
            x = x.to(device)
            y = y.to(device)

            outputs = model(x)
            loss = criterion(outputs, y)
            running_loss += loss.item() * x.size(0)

            # PyTorch has outputs in dimension order (batch_size, channels, height, width)
            # We permute it to (batch_size, height, width, channels)
            photos = outputs.detach().permute(0, 2, 3, 1).cpu().numpy()

            # Converting from floating point ([0, 1]) to uint8 ([0, 255]).
            # Clip first: values slightly outside [0, 1] would otherwise wrap
            # around when cast to uint8 instead of saturating.
            photos = np.clip(photos * 255, 0, 255).astype('uint8')

            for input_emb, photo in zip(x, photos):
                dists.append(distance_metric(photo, input_emb, dlib_models))

    val_loss = running_loss / len(val_dataloader.dataset)
    return val_loss, dists
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment