import torch
import torch.nn as nn


class XlaLSTM(nn.Module):
  """Single-layer LSTM written as an explicit per-timestep loop, with
  padding-aware state updates (the test at the bottom runs it on an XLA
  device)."""

  def __init__(self, input_sz, hidden_sz, batch_first=False, pad_value=0):
    super(XlaLSTM, self).__init__()
    self.input_sz = input_sz
    self.hidden_size = hidden_sz
    # Remember which tensor dimensions hold the batch and the sequence.
    if batch_first:
      self.batch_dim, self.sequence_dim = 0, 1
    else:
      self.batch_dim, self.sequence_dim = 1, 0
    self.pad_value = pad_value
    # Input-to-hidden and hidden-to-hidden weights for the four gates
    # (input, forget, cell, output), packed along the last dimension.
    self.weight_ih = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
    self.weight_hh = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
    self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
    self.init_weights()

  def init_weights(self):
    # Xavier init for the weight matrices, zeros for the bias vector.
    for p in self.parameters():
      if p.data.ndimension() >= 2:
        nn.init.xavier_uniform_(p.data)
      else:
        nn.init.zeros_(p.data)

  def sequence_slice(self, t, embed_out, embed_in):
    # Slice timestep t out of the embedded input and the raw token ids.
    ot = embed_out[:, t, :] if self.sequence_dim == 1 else embed_out[t, :, :]
    it = embed_in[:, t] if self.sequence_dim == 1 else embed_in[t, :]
    return ot, it

  def sizes(self, embed_out):
    size = embed_out.size()
    return size[self.batch_dim], size[self.sequence_dim]

  def forward(self, embed_out, embed_in, init_states=None):
    bs, seq_sz = self.sizes(embed_out)
    if init_states is None:
      h_t, c_t = (torch.zeros(bs, self.hidden_size, device=embed_out.device),
                  torch.zeros(bs, self.hidden_size, device=embed_out.device))
    else:
      h_t, c_t = init_states
    hstate = []
    HS = self.hidden_size
    for t in range(0, seq_sz):
      feat, iseq = self.sequence_slice(t, embed_out, embed_in)
      # One fused matmul per step produces all four gate pre-activations.
      gates = feat @ self.weight_ih + h_t @ self.weight_hh + self.bias
      i_t, f_t, g_t, o_t = (
          torch.sigmoid(gates[:, :HS]),
          torch.sigmoid(gates[:, HS:HS * 2]),
          torch.tanh(gates[:, HS * 2:HS * 3]),
          torch.sigmoid(gates[:, HS * 3:]),
      )
      # Batch entries whose token at step t equals pad_value keep their
      # previous state, so padded steps do not perturb h_t/c_t.
      fwd = iseq.unsqueeze(1) != self.pad_value
      c_t = torch.where(fwd, f_t * c_t + i_t * g_t, c_t)
      h_t = torch.where(fwd, o_t * torch.tanh(c_t), h_t)
      hstate.append(h_t.unsqueeze(self.sequence_dim))
    hstate = torch.cat(hstate, dim=self.sequence_dim)
    return hstate, (h_t, c_t)
#### TEST

import random

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def test(device,
         vocab_size,
         max_len,
         batch_size,
         embed_size=6,
         hs_size=5,
         padding_idx=0):
  embed = nn.Embedding(
      vocab_size, embed_size, padding_idx=padding_idx).to(device)
  lstm = XlaLSTM(embed_size, hs_size, batch_first=True).to(device)
  for _ in range(0, 10):
    # Build a batch of random-length sequences, right-padded with 0 up to
    # max_len so every iteration sees the same tensor shapes.
    batch_data = []
    for b in range(0, batch_size):
      seqlen = random.randint(1, max_len)
      data = [random.randint(1, vocab_size - 1) for _ in range(0, seqlen)]
      data += [0] * (max_len - seqlen)
      batch_data.append(data)
    in_tensor = torch.tensor(batch_data, dtype=torch.int64, device=device)
    embed_output = embed(in_tensor)
    output, (ht, ct) = lstm(embed_output, in_tensor)
    # Dump the XLA IR graph behind the output tensor, then materialize it.
    print('OUTPUT:\n', torch_xla._XLAC._get_xla_tensors_text([output]))
    print(output.cpu())


torch.manual_seed(11)
device = xm.xla_device()
test(device, vocab_size=20, max_len=8, batch_size=2)
print(met.metrics_report())