This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Get the eigenvectors.
# `activations` must be 4-D; its axes are unpacked as F, T, H, W
# (presumably features, time, height, width — TODO confirm).
F, T, H, W = activations.shape
# Flatten the last three axes and transpose, giving one row per (t, h, w)
# location and F columns.
points = activations.reshape([F, T*H*W]).transpose()
# construct_affinity_mat and wncuts are defined elsewhere.
A = construct_affinity_mat(points, sigmas=[6.0])
eigvecs = wncuts(A, num_eigenvectors=100)
# Put the rows back on the (T, H, W) grid; the trailing axis keeps whatever
# per-location values wncuts returned.
eigvecs_reshaped = eigvecs.reshape([T, H, W, -1])
# frames[0].shape is [h, w, c] (per the original note); keep only (h, w).
frame_shape = frames[0].shape[:2]
# Vector stored at grid location (0, 0, 0).
target_vec = eigvecs_reshaped[0, 0, 0]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
import numpy as np | |
import torch.utils.data as data | |
import torch | |
from torchvision.datasets.video_utils import VideoClips | |
class GymnasticsVideo(data.Dataset): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
import os | |
import random | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torch.nn import functional as F | |
from torch.utils.data import IterableDataset | |
from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test(model, dataloader, config): | |
model.eval() | |
num_batches = config['num_test_batches'] | |
running_loss = 0.0 | |
running_accuracy = 0.0 | |
with torch.no_grad(): | |
with tqdm(dataloader, total=num_batches) as pbar: | |
for batch_idx, batch in enumerate(pbar): | |
train_inputs, train_targets = batch['train'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Hard-coded sizes; H, W and T presumably mirror the height/width/time
# extents of the feature grid — TODO confirm against the caller.
H = 10
# make_clusters is defined elsewhere; its result must expose `labels_`,
# which is read below.
kmeans = make_clusters(eigenvectors, num_clusters)
W = 18
T = 128
CD = 8
# T and CD are ints, so floor division already yields an int (16 here).
CR = T // CD
colors = []
labels = kmeans.labels_
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scipy.sparse as sps | |
from scipy.spatial.distance import pdist, squareform | |
def get_normalized_laplacian(sq_distances, sigma_features):
    """Return the normalized Laplacian of a Gaussian affinity graph.

    Args:
        sq_distances: Condensed (vector-form) pairwise distances, as accepted
            by ``scipy.spatial.distance.squareform``. They are used as-is,
            so they are presumably already squared, per the parameter name.
        sigma_features: Non-zero bandwidth of the Gaussian kernel.

    Returns:
        The normalized graph Laplacian of the affinity matrix
        ``exp(-D / (2 * sigma_features**2))``, where ``D`` is the square
        form of ``sq_distances``.

    Raises:
        ValueError: If ``sigma_features`` is zero, which would otherwise
            divide by zero and silently yield NaN/inf affinities.
    """
    if np.any(np.equal(sigma_features, 0)):
        raise ValueError("sigma_features must be non-zero")
    # Expand the condensed distances to a square matrix, then apply the
    # Gaussian kernel to turn distances into affinities.
    affinity = np.exp(-squareform(sq_distances) / (2 * sigma_features**2))
    return sps.csgraph.laplacian(affinity, normed=True, return_diag=False)
def get_square_distances(output): | |
# output: [512, 128, 10, 18] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import gin | |
from meta_dataset.data import config | |
from meta_dataset.data import dataset_spec as dataset_spec_lib | |
from meta_dataset.data import learning_spec | |
from meta_dataset.data import pipeline | |
import numpy as np | |
import tensorflow as tf | |
import torch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
batch_size = inputs.shape[0] | |
input_embs = sender_embedding(inputs) | |
inputs = input_embs.view(batch_size, num_digits * embedding_size_sender) | |
hx = torch.zeros(batch_size, num_lstm_sender) | |
cx = torch.zeros(batch_size, num_lstm_sender) | |
for num in range(num_binary_messages): | |
hx, cx = sender_cell(inputs, (hx, cx)) | |
output = sender_project(hx) | |
pre_logits = sender_out(output) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_graph(self): | |
label = tf.one_hot(self.batch, 10*self._config.num_digits) | |
self.label = tf.argmax(label, -1) | |
num_digits = self._config.num_digits | |
num_binary_messages = self._config.num_binary_messages | |
# Speaker | |
with tf.variable_scope("A1"): | |
weights = tf.get_variable("embeddings", shape=(10*num_digits, self._config.embedding_size), | |
dtype=tf.float32, initializer=tf.orthogonal_initializer) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample from a relaxed one-hot categorical distribution at temperature tau.
q_y = tf.contrib.distributions.RelaxedOneHotCategorical(tau, logits=a1_logits)
y = q_y.sample()
# Discretize the sample: one-hot of its argmax over the last axis, cast back
# to the sample's dtype.
y_hard = tf.cast(tf.one_hot(tf.argmax(y, -1), output_size), y.dtype)
# Append one extra one-hot row (1 in the last position) along axis 1 so that
# argmax doesn't use an incorrect index.
one_hot = np.array([0] * (output_size - 1) + [1]).astype(np.float32)
# Shape the row as [1, 1, output_size], then repeat it along the first axis
# to match y_hard's leading dimension (presumably the batch size).
concat_one_hot = tf.expand_dims(tf.expand_dims(tf.convert_to_tensor(one_hot), 0), 0)
concat_one_hot = tf.tile(concat_one_hot, tf.stack([tf.shape(y_hard)[0], 1, 1]))
concat_y_hard = tf.concat([y_hard, concat_one_hot], 1)
NewerOlder