I have a bunch of code snippets that I use quite frequently, so I log them all here.
Quick Code Snippets
import os

# Create the target directory if it does not already exist
if not os.path.exists(directory):
    os.makedirs(directory)
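On Python 3.2+ the same thing can be done without the separate existence check; a minimal sketch, reusing the `directory` variable from the snippet above:

# exist_ok=True makes makedirs a no-op if the directory already exists
os.makedirs(directory, exist_ok=True)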
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm


def get_mean_and_std(dataset, device):
    """
    Finds the mean and std of the dataset.
    This is used for Normalization of the dataset.

    :param dataset: dataset to calculate mean and std for
    :param device: CUDA or CPU?
    :return: tuple of mean and std
    """
    dataloader = DataLoader(dataset, batch_size=4096)
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
    for data, _ in tqdm(dataloader):
        data = data.to(device)
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1
    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5
    return mean, std
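A minimal usage sketch, assuming a torchvision image dataset that yields (image, label) pairs; the CIFAR-10 choice and the data root path are illustrative assumptions, not part of the snippet:

import torch
import torchvision
import torchvision.transforms as transforms

# Hypothetical example: compute normalization stats, then build a Normalize transform from them
dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True,
                                       transform=transforms.ToTensor())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mean, std = get_mean_and_std(dataset, device)
normalize = transforms.Normalize(mean=mean.tolist(), std=std.tolist())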
from collections import OrderedDict


def remove_dataparallel_wrapper(state_dict):
    """
    Converts a DataParallel state dict to a plain module state dict by removing
    the "module." prefix from each key.

    :param state_dict: a torch.nn.DataParallel state dictionary
    :return: a torch.nn.Module state dictionary
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove 'module.' of DataParallel
        new_state_dict[name] = v
    return new_state_dict
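A usage sketch, assuming `model` is an ordinary nn.Module defined elsewhere and that the checkpoint path is hypothetical:

import torch

# Hypothetical example: load weights that were saved from a DataParallel-wrapped model
state_dict = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(remove_dataparallel_wrapper(state_dict))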
import os
import random

import numpy as np
import torch
from torch.backends import cudnn


def set_seed(seed):
    """
    Seeds pretty much everything that can be.

    :param seed: the seed number to be used
    :return: None
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False
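Typical usage is to call this once at the very start of the script, before building datasets, models, or dataloaders; the seed value here is an arbitrary illustrative choice:

# Hypothetical example: seed everything before any randomness is consumed
set_seed(42)

Note that DataLoader worker processes draw their own seeds, so for fully deterministic data loading you may also want to pass a seeded `generator` and a `worker_init_fn` to the DataLoader.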