Skip to content

Instantly share code, notes, and snippets.

@fennecinspace
Created November 19, 2024 15:58
Show Gist options
  • Save fennecinspace/76d7a70b63c809c6ada150fda3bef7cd to your computer and use it in GitHub Desktop.
Save fennecinspace/76d7a70b63c809c6ada150fda3bef7cd to your computer and use it in GitHub Desktop.
# Train `model` for `num_epochs` epochs, evaluating on `val_loader` after each
# epoch and logging the epoch metrics to a Weights & Biases run.
#
# NOTE(review): this cell relies on names defined in earlier cells and not
# visible here: `model`, `criterion`, `optimizer`, `train_loader`,
# `val_loader`, `device`, `torch` and `tqdm` — confirm they exist first.
import wandb

num_epochs = 1

entity = "mohamedinspace"
project = "Recaptcha-Solver"
name = "ResNet10"  # you can change the name of your runs

wb = wandb.init(
    entity=entity,
    project=project,
    name=name,
)

for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    train_loader_tqdm = tqdm(
        train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Training]"
    )
    for inputs, labels in train_loader_tqdm:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients, run forward/backward, update weights.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() copies the scalar to the host (a sync point on GPU);
        # call it once per batch and reuse the value.
        batch_loss = loss.item()
        batch_size = labels.size(0)
        # Weight by batch size so a smaller final batch does not skew the
        # epoch mean. Assumes `criterion` returns the batch *mean* loss
        # (e.g. reduction="mean") — TODO confirm against its definition.
        running_loss += batch_loss * batch_size
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

        train_loader_tqdm.set_postfix(loss=batch_loss, accuracy=correct / total)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    train_loss = running_loss / total if total else float("nan")
    train_accuracy = correct / total if total else float("nan")

    # ---- Validation phase ----
    model.eval()
    val_loss = 0.0  # sum of per-sample losses over the validation set
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            # Same sample weighting (and mean-reduction assumption) as above.
            val_loss += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += batch_size

    val_loss = val_loss / val_total if val_total else float("nan")
    val_accuracy = val_correct / val_total if val_total else float("nan")

    # `sync` is deprecated in wandb's Run.log and has no effect, so it is
    # omitted. The explicit step keeps W&B's x-axis aligned with epoch number.
    wb.log(
        {
            "Train Loss": train_loss,
            "Train Accuracy": train_accuracy,
            "Val Loss": val_loss,
            "Val Accuracy": val_accuracy,
        },
        step=epoch + 1,
        commit=True,
    )

    print(f"Epoch {epoch+1}/{num_epochs}, "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

# NOTE(review): the run is left open (as before) so later cells can keep
# logging to `wb`; call `wb.finish()` once you are done with it.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment