Skip to content

Instantly share code, notes, and snippets.

@Burntt
Created March 1, 2022 06:57
Show Gist options
Save Burntt/a3879b6e8326a0d17c6ae370bb8f3606 to your computer and use it in GitHub Desktop.
# Train function
def train(fold, model, device, trainloader, optimizer, epoch, loss_fn=None):
    """Run one training epoch over ``trainloader`` and count correct predictions.

    Args:
        fold: Cross-validation fold identifier; only used in progress output.
        model: Model to train; put in training mode and called on each batch.
        device: Device each batch is moved to via ``.to(device)``.
        trainloader: Iterable of ``(data, target)`` batches that also supports
            ``len()`` and exposes a ``.dataset`` supporting ``len()`` (both are
            used in the progress output), e.g. a ``torch.utils.data.DataLoader``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        epoch: Epoch number; only used in progress output.
        loss_fn: Loss callable invoked as ``loss_fn(output, target)``. When
            omitted, the module-level ``criterion`` is used, as before.

    Returns:
        A tuple ``(correct_train, train_loss)``. ``correct_train`` is the number
        of samples whose predicted index (max over dim 1 of the model output)
        equals the flattened target. ``train_loss`` is the loss tensor of the
        *last* batch only, not an epoch aggregate.

    Raises:
        ValueError: If ``trainloader`` yields no batches.
    """
    if loss_fn is None:
        # Backward-compatible default: the original code read this global.
        loss_fn = criterion

    model.train()
    correct_train = 0
    train_loss = None
    for batch_idx, (data, target) in enumerate(trainloader):
        data, target = data.to(device), target.to(device)
        # Both the loss and the accuracy check use the 1-D flattened target.
        flat_target = target.flatten()

        optimizer.zero_grad()
        output = model(data)
        train_loss = loss_fn(output, flat_target)
        train_loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Fold/Epoch: {}/{} [{}/{} ({:.0f}%)]\ttrain_loss: {:.6f}'.format(
                fold, epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), train_loss.item()))

        # Measure accuracy on train set: index of the max score per row.
        _, y_pred_tags_train = torch.max(output, dim=1)
        correct_this_batch_train = y_pred_tags_train.eq(flat_target.view_as(y_pred_tags_train))
        correct_train += correct_this_batch_train.sum().item()

    if train_loss is None:
        # Without this guard the return below would hit an unbound local.
        raise ValueError('trainloader yielded no batches')
    return correct_train, train_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment