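# A masked LSTM written to play well with PyTorch/XLA: variable length
# sentences are padded into a small set of fixed bucket lengths, and padded
# timesteps are skipped with torch.where() rather than data-dependent Python
# control flow, so the traced graph shapes stay static.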
import torch
import torch.nn as nn


class XlaLSTM(nn.Module):

  def __init__(self, input_sz, hidden_sz, batch_first=False, pad_value=0):
    super(XlaLSTM, self).__init__()
    self.input_sz = input_sz
    self.hidden_size = hidden_sz
    if batch_first:
      self.batch_dim, self.sequence_dim = 0, 1
    else:
      self.batch_dim, self.sequence_dim = 1, 0
    self.pad_value = pad_value
    # Fused gate weights: the input, forget, cell and output gates are stacked
    # along the last dimension.
    self.weight_ih = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
    self.weight_hh = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
    self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
    self.init_weights()

  def init_weights(self):
    for p in self.parameters():
      if p.data.ndimension() >= 2:
        nn.init.xavier_uniform_(p.data)
      else:
        nn.init.zeros_(p.data)

  def sequence_slice(self, t, embed_out, embed_in):
    # Returns the t-th timestep of the embedded input and of the raw token ids
    # (the latter is used to detect padding).
    ot = embed_out[:, t, :] if self.sequence_dim == 1 else embed_out[t, :, :]
    it = embed_in[:, t] if self.sequence_dim == 1 else embed_in[t, :]
    return ot, it

  def sizes(self, embed_out):
    size = embed_out.size()
    return size[self.batch_dim], size[self.sequence_dim]

  def forward(self, embed_out, embed_in, init_states=None):
    bs, seq_sz = self.sizes(embed_out)
    if init_states is None:
      h_t, c_t = (torch.zeros(bs, self.hidden_size, device=embed_out.device),
                  torch.zeros(bs, self.hidden_size, device=embed_out.device))
    else:
      h_t, c_t = init_states
    hstate = []
    HS = self.hidden_size
    for t in range(0, seq_sz):
      feat, iseq = self.sequence_slice(t, embed_out, embed_in)
      gates = feat @ self.weight_ih + h_t @ self.weight_hh + self.bias
      i_t, f_t, g_t, o_t = (
          torch.sigmoid(gates[:, :HS]),
          torch.sigmoid(gates[:, HS:HS * 2]),
          torch.tanh(gates[:, HS * 2:HS * 3]),
          torch.sigmoid(gates[:, HS * 3:]),
      )
      # Only advance the state where the current token is not padding; using
      # torch.where() keeps the graph free of data-dependent control flow.
      fwd = iseq.unsqueeze(1) != self.pad_value
      c_t = torch.where(fwd, f_t * c_t + i_t * g_t, c_t)
      h_t = torch.where(fwd, o_t * torch.tanh(c_t), h_t)
      hstate.append(h_t.unsqueeze(self.sequence_dim))
    hstate = torch.cat(hstate, dim=self.sequence_dim)
    return hstate, (h_t, c_t)
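# Standalone usage sketch (names and sizes are illustrative): with
# batch_first=True the module takes the embedded batch, shaped
# (batch, seq, input_sz), together with the raw (batch, seq) token id tensor
# used to locate padding:
#
#   lstm = XlaLSTM(input_sz=128, hidden_sz=256, batch_first=True)
#   hstate, (h_t, c_t) = lstm(embedded_batch, token_ids)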
#### TEST
import collections
import csv
import nltk
import random
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.utils.gcsfs as gcs


def maybe_int(v):
  try:
    return int(v)
  except ValueError:
    return None


def read_dataset(path):
  # Reads a CSV with 'sentence' and 'target' columns (possibly from GCS),
  # tokenizes the sentences and builds an integer vocabulary, with id 0
  # reserved for padding.
  with gcs.generic_open(path, mode='r') as fd:
    reader = csv.DictReader(fd)
    vocab = {'PAD': 0}
    sentences, max_target = [], -1
    for fields in reader:
      target = maybe_int(fields['target'])
      if target is None:
        continue
      sentence = []
      for tok in nltk.word_tokenize(fields['sentence']):
        ltok = tok.lower()
        tid = vocab.get(ltok, None)
        if tid is None:
          tid = len(vocab)
          vocab[ltok] = tid
        sentence.append(tid)
      if sentence:
        sentences.append((sentence, target))
        max_target = max(max_target, target)
  return sentences, vocab, max_target
def split_and_pad(sentences, splits, pad_value=0):
  # Buckets each sentence into one of the split lengths and pads it with
  # pad_value to that length; sentences longer than the largest split are
  # discarded.
  splits = sorted(splits)
  split_dict = collections.defaultdict(list)
  discarded = 0
  for sentence, target in sentences:
    lsent = len(sentence)
    for i in range(0, len(splits)):
      if lsent < splits[i]:
        break
    seqlen = splits[i]
    if lsent > seqlen:
      discarded += 1
      continue
    padded_sentence = sentence + [pad_value] * (seqlen - lsent)
    split_dict[seqlen].append((padded_sentence, target))
  return split_dict, discarded
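# Worked example (hypothetical values): with splits=(8, 16) a 10-token
# sentence lands in the 16 bucket and gets six trailing pad_value entries,
# while a 20-token sentence exceeds the largest split and is discarded.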
def make_batches(split_dict, batch_size):
  batch_dict = dict()
  for seqlen, slist in split_dict.items():
    batches = []
    i = 0
    while i + batch_size <= len(slist):
      batches.append(slist[i:i + batch_size])
      i += batch_size
    if i < len(slist):
      # Top up the trailing partial batch with randomly resampled sentences so
      # that every batch has the same static shape.
      batch = slist[i:]
      while len(batch) < batch_size:
        batch.append(slist[random.randint(0, len(slist) - 1)])
      batches.append(batch)
    batch_dict[seqlen] = batches
  return batch_dict
def to_one_hot(y, n_dims, dtype=torch.float32):
  scatter_dim = len(y.size())
  y_tensor = y.view(*y.size(), -1)
  zeros = torch.zeros(*y.size(), n_dims, dtype=dtype, device=y.device)
  return zeros.scatter(scatter_dim, y_tensor, 1)
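# Worked example (hypothetical values): to_one_hot(torch.tensor([1, 0, 2]), 3)
# returns [[0., 1., 0.], [1., 0., 0.], [0., 0., 1.]].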
def gen_tensors(batch_dict, target_dims, shuffle=True):
  slist = []
  for seqlen in sorted(batch_dict.keys()):
    slist += batch_dict[seqlen]
  if shuffle:
    random.shuffle(slist)
  tensors = []
  for bseq in slist:
    sentence_data, target_data = [], []
    for sentence, target in bseq:
      sentence_data.append(sentence)
      target_data.append(target)
    sentence_tensor = torch.tensor(sentence_data, dtype=torch.int64)
    target_tensor = torch.tensor(target_data, dtype=torch.int64)
    onehot_tensor = to_one_hot(target_tensor, target_dims)
    tensors.append((sentence_tensor, onehot_tensor))
  return tensors
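# Each returned pair is a (batch_size, seqlen) int64 tensor of token ids and a
# (batch_size, target_dims) one-hot float tensor, so the training loop only
# ever sees a handful of static shapes.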
class TestClassifier(nn.Module):

  def __init__(self,
               vocab_size,
               embedding_dim,
               hidden_dim,
               output_dim,
               padding_idx=None):
    super().__init__()
    self.embedding = nn.Embedding(
        vocab_size, embedding_dim, padding_idx=padding_idx)
    self.lstm = XlaLSTM(embedding_dim, hidden_dim, batch_first=True)
    self.fc = nn.Linear(hidden_dim, output_dim)
    self.act = nn.Sigmoid()

  def forward(self, sentence_tensor):
    embedded = self.embedding(sentence_tensor)
    # The raw token ids are passed alongside the embeddings so the LSTM can
    # mask out padded positions.
    output, (hidden, cell) = self.lstm(embedded, sentence_tensor)
    dense_outputs = self.fc(hidden)
    return self.act(dense_outputs)
def test_model(path,
               device,
               splits,
               batch_size,
               embed_size=8,
               per_target_hs_size=256,
               lr=0.01,
               momentum=0.0,
               epochs=1,
               log_interval=10):
  MIN_HS_SIZE = 128
  MAX_HS_SIZE = 1024 * 32
  sentences, vocab, max_target = read_dataset(path)
  split_dict, discarded = split_and_pad(sentences, splits)
  batch_dict = make_batches(split_dict, batch_size)
  train_tensors = gen_tensors(batch_dict, max_target + 1)
  # Scale the hidden size with the number of target classes, clamped to a
  # sane range.
  hs_size = min(MAX_HS_SIZE,
                max((max_target + 1) * per_target_hs_size, MIN_HS_SIZE))
  model = TestClassifier(
      len(vocab), embed_size, hs_size, max_target + 1, padding_idx=0)
  model.to(device)
  model.train()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
  criterion = nn.MSELoss().to(device)
  for epoch in range(0, epochs):
    n = 0
    for sentence_tensor, target_tensor in train_tensors:
      sentence_tensor = sentence_tensor.to(device)
      target_tensor = target_tensor.to(device)
      optimizer.zero_grad()
      output = model(sentence_tensor)
      loss = criterion(output, target_tensor)
      loss.backward()
      xm.optimizer_step(optimizer, barrier=True)
      # print('OUTPUT:\n', torch_xla._XLAC._get_xla_tensors_text([output]))
      n += 1
      if n % log_interval == 0:
        print('[{}] Loss: {:.4f}'.format(epoch, loss.cpu().item()))
torch.manual_seed(11)
device = xm.xla_device()
csv_path = 'gs://davide-stg1/lstm_test_data.csv'
nltk.download('punkt')
test_model(
    csv_path,
    device, (8, 16, 32, 64),
    batch_size=16,
    embed_size=128,
    lr=0.01,
    momentum=0.9,
    epochs=10)
print(met.metrics_report())
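# Note: with XLA lazy tensors every new input shape triggers a fresh graph
# compilation, so limiting the sequence lengths to the (8, 16, 32, 64) buckets
# above keeps the number of compiled graphs small; inspecting the metrics
# report printed here (e.g. the CompileTime metric) is one way to check that
# recompilations stay bounded.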