
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import string | |
import argparse | |
import time | |
import ray | |
import pandas as pd | |
import numpy as np | |
import uuid | |
import os |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os.path | |
import psutil | |
import pyarrow as pa | |
import numpy as np | |
from pyarrow import parquet as pq | |
import time | |
WINDOW_LENGTH = 1000 | |
N = 1000000 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
bodies = torch.zeros((2, 1, 7, 7)) | |
heads = torch.zeros((2, 1, 7, 7)) | |
num_envs = bodies.size(0) | |
# Initialise body as shown in diagram | |
bodies[:, :, 3, 2] = 1 | |
bodies[:, :, 3, 3] = 2 | |
bodies[:, :, 2, 3] = 3 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
movement_filters = torch.Tensor([ | |
[ | |
[0, 1, 0], | |
[0, 0, 0], | |
[0, 0, 0], | |
], | |
[ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchvision import transforms, datasets | |
from torch import nn, optim | |
from torch.utils.data import DataLoader | |
import torch.nn.functional as F | |
class Net(nn.Module): | |
def __init__(self): | |
super(Net, self).__init__() | |
self.conv1 = nn.Conv2d(1, 32, kernel_size=5) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def projected_gradient_descent(model, x, y, loss_fn, num_steps, step_size, step_norm, eps, eps_norm, | |
clamp=(0,1), y_target=None): | |
"""Performs the projected gradient descent attack on a batch of images.""" | |
x_adv = x.clone().detach().requires_grad_(True).to(x.device) | |
targeted = y_target is not None | |
num_channels = x.shape[1] | |
for i in range(num_steps): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
def replace_grad(parameter_gradients, parameter_name): | |
"""Creates a backward hook function that replaces the calculated gradient | |
with a precomputed value when .backward() is called. | |
See | |
https://pytorch.org/docs/stable/autograd.html?highlight=hook#torch.Tensor.register_hook | |
for more info |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def proto_net_episode(model: Module, | |
optimiser: Optimizer, | |
loss_fn: Callable, | |
x: torch.Tensor, | |
y: torch.Tensor, | |
n_shot: int, | |
k_way: int, | |
q_queries: int, | |
distance: str, | |
train: bool): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nn.utils import clip_grad_norm_ | |
def matching_net_episode(model: Module, | |
optimiser: Optimizer, | |
loss_fn: Loss, | |
x: torch.Tensor, | |
y: torch.Tensor, | |
n_shot: int, | |
k_way: int, |
NewerOlder