import torch
import torch.nn as nn


class XlaLSTM(nn.Module):
  """LSTM that handles padded variable-length batches with fixed shapes.

  Padded steps are masked out with torch.where() instead of packing the
  sequences, so every step issues the same fixed-shape operations.
  """

  def __init__(self, input_sz, hidden_sz, batch_first=False, pad_value=0):
    super(XlaLSTM, self).__init__()
    self.input_sz = input_sz
    self.hidden_size = hidden_sz
    if batch_first:
      self.batch_dim, self.sequence_dim = 0, 1
    else:
      self.batch_dim, self.sequence_dim = 1, 0
    self.pad_value = pad_value
    # The four gates (i, f, g, o) are packed along the last dimension.
    self.weight_ih = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
    self.weight_hh = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
    self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
    self.init_weights()

  def init_weights(self):
    for p in self.parameters():
      if p.data.ndimension() >= 2:
        nn.init.xavier_uniform_(p.data)
      else:
        nn.init.zeros_(p.data)

  def sequence_slice(self, t, embed_out, embed_in):
    # Step t of the embedded inputs and of the raw token ids.
    ot = embed_out[:, t, :] if self.sequence_dim == 1 else embed_out[t, :, :]
    it = embed_in[:, t] if self.sequence_dim == 1 else embed_in[t, :]
    return ot, it

  def sizes(self, embed_out):
    size = embed_out.size()
    return size[self.batch_dim], size[self.sequence_dim]

  def forward(self, embed_out, embed_in, init_states=None):
    bs, seq_sz = self.sizes(embed_out)
    if init_states is None:
      h_t, c_t = (torch.zeros(bs, self.hidden_size, device=embed_out.device),
                  torch.zeros(bs, self.hidden_size, device=embed_out.device))
    else:
      h_t, c_t = init_states
    hstate = []
    HS = self.hidden_size
    for t in range(0, seq_sz):
      feat, iseq = self.sequence_slice(t, embed_out, embed_in)
      gates = feat @ self.weight_ih + h_t @ self.weight_hh + self.bias
      i_t, f_t, g_t, o_t = (
          torch.sigmoid(gates[:, :HS]),
          torch.sigmoid(gates[:, HS:HS * 2]),
          torch.tanh(gates[:, HS * 2:HS * 3]),
          torch.sigmoid(gates[:, HS * 3:]),
      )
      # Where the input token is padding, keep the previous h_t/c_t
      # instead of updating, so shorter sequences stop evolving early.
      fwd = iseq.unsqueeze(1) != self.pad_value
      c_t = torch.where(fwd, f_t * c_t + i_t * g_t, c_t)
      h_t = torch.where(fwd, o_t * torch.tanh(c_t), h_t)
      hstate.append(h_t.unsqueeze(self.sequence_dim))
    hstate = torch.cat(hstate, dim=self.sequence_dim)
    return hstate, (h_t, c_t)
#### TEST
import random

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def test(device,
         vocab_size,
         max_len,
         batch_size,
         embed_size=6,
         hs_size=5,
         padding_idx=0):
  embed = nn.Embedding(
      vocab_size, embed_size, padding_idx=padding_idx).to(device)
  lstm = XlaLSTM(embed_size, hs_size, batch_first=True).to(device)
  for _ in range(0, 10):
    # Random-length sequences, right-padded with 0 to max_len so every
    # batch tensor has the same shape.
    batch_data = []
    for b in range(0, batch_size):
      seqlen = random.randint(1, max_len)
      data = [random.randint(1, vocab_size - 1) for _ in range(0, seqlen)]
      data += [0] * (max_len - seqlen)
      batch_data.append(data)
    in_tensor = torch.tensor(batch_data, dtype=torch.int64, device=device)
    embed_output = embed(in_tensor)
    output, (ht, ct) = lstm(embed_output, in_tensor)
    # Dump the XLA IR built for the output tensor.
    print('OUTPUT:\n', torch_xla._XLAC._get_xla_tensors_text([output]))
    print(output.cpu())


torch.manual_seed(11)
device = xm.xla_device()
test(device, vocab_size=20, max_len=8, batch_size=2)
print(met.metrics_report())
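
# Optional CPU sanity check (not part of the original gist; a sketch relying
# on documented nn.LSTM behavior): nn.LSTM packs its gates in the same
# i, f, g, o order, stores its weights as (4*hidden, input) matrices and
# keeps separate ih/hh biases, so transposing the weights and summing the
# biases maps them onto XlaLSTM's layout. The helper name below is
# hypothetical.
def check_against_nn_lstm(input_sz=6, hidden_sz=5):
  ref = nn.LSTM(input_sz, hidden_sz, batch_first=True)
  xla_lstm = XlaLSTM(input_sz, hidden_sz, batch_first=True)
  with torch.no_grad():
    xla_lstm.weight_ih.copy_(ref.weight_ih_l0.t())
    xla_lstm.weight_hh.copy_(ref.weight_hh_l0.t())
    xla_lstm.bias.copy_(ref.bias_ih_l0 + ref.bias_hh_l0)
  x = torch.randn(2, 4, input_sz)
  # Unpadded batch: all token ids non-zero, so the mask never fires and
  # the two modules should agree step for step.
  ids = torch.ones(2, 4, dtype=torch.int64)
  out, _ = xla_lstm(x, ids)
  ref_out, _ = ref(x)
  print('unpadded max diff:', (out - ref_out).abs().max().item())
  # Padded batch: row 1 is only 2 steps long. XlaLSTM freezes h_t/c_t once
  # padding starts (the random values of x at padded steps are masked out),
  # so its final states should match what nn.LSTM computes on the
  # equivalent packed input.
  ids = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
  packed = nn.utils.rnn.pack_padded_sequence(
      x, torch.tensor([4, 2]), batch_first=True)
  _, (ref_ht, ref_ct) = ref(packed)
  _, (ht, ct) = xla_lstm(x, ids)
  print('padded h_t max diff:', (ht - ref_ht[0]).abs().max().item())
  print('padded c_t max diff:', (ct - ref_ct[0]).abs().max().item())


check_against_nn_lstm()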